add notargetgroup

This commit is contained in:
felinae98 2021-07-31 12:43:19 +08:00
parent cf6ad6bc82
commit 7b6226a833
No known key found for this signature in database
GPG Key ID: 00C8B010587FF610
7 changed files with 127 additions and 12 deletions

View File

@ -22,8 +22,7 @@ class Arknights(NewMessage, NoTargetMixin):
schedule_type = 'interval' schedule_type = 'interval'
schedule_kw = {'seconds': 30} schedule_kw = {'seconds': 30}
@staticmethod async def get_target_name(self, _: Target) -> str:
async def get_target_name(_: Target) -> str:
return '明日方舟游戏内公告' return '明日方舟游戏内公告'
async def get_sub_list(self, _) -> list[RawPost]: async def get_sub_list(self, _) -> list[RawPost]:

View File

@ -24,8 +24,7 @@ class Bilibili(NewMessage, TargetMixin):
schedule_kw = {'seconds': 10} schedule_kw = {'seconds': 10}
name = 'B站' name = 'B站'
@staticmethod async def get_target_name(self, target: Target) -> Optional[str]:
async def get_target_name(target: Target) -> Optional[str]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
res = await client.get('https://api.bilibili.com/x/space/acc/info', params={'mid': target}) res = await client.get('https://api.bilibili.com/x/space/acc/info', params={'mid': target})
res_data = json.loads(res.text) res_data = json.loads(res.text)

View File

@ -1,5 +1,7 @@
from abc import abstractmethod, ABC from abc import abstractmethod, ABC
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce
import time import time
from typing import Any, Collection, Optional, Literal from typing import Any, Collection, Optional, Literal
@ -230,9 +232,8 @@ class Platform(PlatformNameMixin, UserCustomFilterMixin, base=True):
enabled: bool enabled: bool
name: str name: str
@staticmethod
@abstractmethod @abstractmethod
async def get_target_name(target: Target) -> Optional[str]: async def get_target_name(self, target: Target) -> Optional[str]:
... ...
@abstractmethod @abstractmethod
@ -299,3 +300,50 @@ class StatusChange(
except httpx.RequestError as err: except httpx.RequestError as err:
logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url))
return [] return []
class NoTargetGroup(
Platform,
NoTargetMixin,
UserCustomFilterMixin,
abstract=True
):
enable_tag = False
DUMMY_STR = '_DUMMY'
enabled = True
class PlatformProto(Platform, NoTargetMixin, UserCustomFilterMixin, abstract=True):
...
def __init__(self, platform_list: list[PlatformProto]):
self.platform_list = platform_list
name = self.DUMMY_STR
self.categories = {}
categories_keys = set()
self.schedule_type = platform_list[0].schedule_type
self.schedule_kw = platform_list[0].schedule_kw
for platform in platform_list:
if name == self.DUMMY_STR:
name = platform.name
elif name != platform.name:
raise RuntimeError('Platform name for {} not fit'.format(self.platform_name))
platform_category_key_set = set(platform.categories.keys())
if platform_category_key_set & categories_keys:
raise RuntimeError('Platform categories for {} duplicate'.format(self.platform_name))
categories_keys |= platform_category_key_set
self.categories.update(platform.categories)
if platform.schedule_kw != self.schedule_kw or platform.schedule_type != self.schedule_type:
raise RuntimeError('Platform scheduler for {} not fit'.format(self.platform_name))
self.name = name
self.is_common = platform_list[0].is_common
super().__init__()
async def get_target_name(self, _):
return await self.platform_list[0].get_target_name(_)
async def fetch_new_post(self, target, users):
res = defaultdict(list)
for platform in self.platform_list:
platform_res = await platform.fetch_new_post(target, users)
for user, posts in platform_res:
res[user].extend(posts)
return [[key, val] for key, val in res.items()]

View File

@ -20,8 +20,7 @@ class Rss(NewMessage, TargetMixin):
schedule_type = 'interval' schedule_type = 'interval'
schedule_kw = {'seconds': 30} schedule_kw = {'seconds': 30}
@staticmethod async def get_target_name(self, target: Target) -> Optional[str]:
async def get_target_name(target: Target) -> Optional[str]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
res = await client.get(target, timeout=10.0) res = await client.get(target, timeout=10.0)
feed = feedparser.parse(res.text) feed = feedparser.parse(res.text)

View File

@ -27,8 +27,7 @@ class Weibo(NewMessage, TargetMixin):
schedule_type = 'interval' schedule_type = 'interval'
schedule_kw = {'seconds': 10} schedule_kw = {'seconds': 10}
@staticmethod async def get_target_name(self, target: Target) -> Optional[str]:
async def get_target_name(target: Target) -> Optional[str]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
param = {'containerid': '100505' + target} param = {'containerid': '100505' + target}
res = await client.get('https://m.weibo.cn/api/container/getIndex', params=param) res = await client.get('https://m.weibo.cn/api/container/getIndex', params=param)

View File

@ -6,7 +6,7 @@ Target = NewType('Target', str)
Category = NewType('Category', int) Category = NewType('Category', int)
Tag = NewType('Tag', str) Tag = NewType('Tag', str)
@dataclass @dataclass(eq=True, frozen=True)
class User: class User:
user: str user: str
user_type: str user_type: str

View File

@ -131,7 +131,8 @@ def mock_platform_no_target(plugin_module: 'nonebot_hk_reporter'):
name = 'Mock Platform' name = 'Mock Platform'
enabled = True enabled = True
is_common = True is_common = True
schedule_interval = 10 schedule_type = 'interval'
schedule_kw = {'seconds': 30}
enable_tag = True enable_tag = True
categories = { categories = {
1: '转发', 1: '转发',
@ -172,6 +173,62 @@ def mock_platform_no_target(plugin_module: 'nonebot_hk_reporter'):
return MockPlatform() return MockPlatform()
@pytest.fixture
def mock_platform_no_target_2(plugin_module: 'nonebot_hk_reporter'):
class MockPlatform(plugin_module.platform.platform.NewMessage,
plugin_module.platform.platform.NoTargetMixin):
platform_name = 'mock_platform'
name = 'Mock Platform'
enabled = True
schedule_type = 'interval'
schedule_kw = {'seconds': 30}
is_common = True
enable_tag = True
categories = {
4: 'leixing4',
5: 'leixing5',
}
def __init__(self):
self.sub_index = 0
super().__init__()
@staticmethod
async def get_target_name(_: 'Target'):
return 'MockPlatform'
def get_id(self, post: 'RawPost') -> Any:
return post['id']
def get_date(self, raw_post: 'RawPost') -> float:
return raw_post['date']
def get_tags(self, raw_post: 'RawPost') -> list['Tag']:
return raw_post['tags']
def get_category(self, raw_post: 'RawPost') -> 'Category':
return raw_post['category']
async def parse(self, raw_post: 'RawPost') -> 'Post':
return plugin_module.post.Post('mock_platform_2', raw_post['text'], 'http://t.tt/' + str(self.get_id(raw_post)), target_name='Mock')
async def get_sub_list(self, _: 'Target'):
list_1 = [
{'id': 5, 'text': 'p5', 'date': now, 'tags': ['tag1'], 'category': 4}
]
list_2 = list_1 + [
{'id': 6, 'text': 'p6', 'date': now, 'tags': ['tag1'], 'category': 4},
{'id': 7, 'text': 'p7', 'date': now, 'tags': ['tag2'], 'category': 5},
]
if self.sub_index == 0:
self.sub_index += 1
return list_1
else:
return list_2
return MockPlatform()
@pytest.fixture @pytest.fixture
def mock_status_change(plugin_module: 'nonebot_hk_reporter'): def mock_status_change(plugin_module: 'nonebot_hk_reporter'):
class MockPlatform(plugin_module.platform.platform.StatusChange, class MockPlatform(plugin_module.platform.platform.StatusChange,
@ -300,3 +357,17 @@ async def test_status_change(mock_status_change, user_info_factory):
assert(len(res3[1][1]) == 0) assert(len(res3[1][1]) == 0)
res4 = await mock_status_change.fetch_new_post('dummy', [user_info_factory(lambda _: [1,2], lambda _: [])]) res4 = await mock_status_change.fetch_new_post('dummy', [user_info_factory(lambda _: [1,2], lambda _: [])])
assert(len(res4) == 0) assert(len(res4) == 0)
@pytest.mark.asyncio
async def test_group(plugin_module: 'nonebot_hk_reporter', mock_platform_no_target, mock_platform_no_target_2, user_info_factory):
group_platform = plugin_module.platform.platform.NoTargetGroup([mock_platform_no_target, mock_platform_no_target_2])
res1 = await group_platform.fetch_new_post('dummy', [user_info_factory(lambda _: [1,4], lambda _: [])])
assert(len(res1) == 0)
res2 = await group_platform.fetch_new_post('dummy', [user_info_factory(lambda _: [1,4], lambda _: [])])
assert(len(res2) == 1)
posts = res2[0][1]
assert(len(posts) == 2)
id_set_2 = set(map(lambda x: x.text, posts))
assert('p2' in id_set_2 and 'p6' in id_set_2)
res3 = await group_platform.fetch_new_post('dummy', [user_info_factory(lambda _: [1,4], lambda _: [])])
assert(len(res3) == 0)