Merge branch 'main' into arknights

This commit is contained in:
felinae98
2021-04-27 10:50:49 +08:00
15 changed files with 468 additions and 392 deletions
@@ -7,12 +7,16 @@ from nonebot.rule import to_me
from nonebot.typing import T_State
from .config import Config, NoSuchSubscribeException
from .platform import platform_manager
from .platform.utils import check_sub_target
from .platform import platform_manager, check_sub_target
from .send import send_msgs
from .utils import parse_text
from .types import Target
# Names of the enabled platforms that are flagged as common; these are the
# ones listed in the first subscription prompt.
common_platform = [
    candidate.platform_name
    for candidate in platform_manager.values()
    if candidate.enabled and candidate.is_common
]
help_match = on_command('help', rule=to_me(), priority=5)
@help_match.handle()
async def send_help(bot: Bot, event: Event, state: T_State):
@@ -20,7 +24,32 @@ async def send_help(bot: Bot, event: Event, state: T_State):
await help_match.finish(Message(await parse_text(message)))
add_sub = on_command("添加订阅", rule=to_me(), permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, priority=5)
@add_sub.got('platform', '请输入想要订阅的平台,目前支持:{}'.format(', '.join(platform_manager.keys())))
@add_sub.handle()
async def add_sub_handle_platform(bot: Bot, event: Event, state: T_State):
    """Ask the user which platform to subscribe to.

    Does nothing when ``state`` already holds a ``'platform'`` entry.
    Otherwise sends one ``platform_name: display name`` line per common
    platform and then calls ``add_sub.pause()``.
    """
    if 'platform' in state:
        return
    # Separate the id the user has to type from the display name; without a
    # separator entries such as "rss" + "Rss" run together as "rssRss".
    platform_lines = ''.join(
        '{}: {}\n'.format(platform_name, platform_manager[platform_name].name)
        for platform_name in common_platform
    )
    await bot.send(
        event=event,
        message='请输入想要订阅的平台,目前支持:\n'
        + platform_lines
        + '要查看全部平台请输入:“全部”',
    )
    await add_sub.pause()
@add_sub.handle()
async def add_sub_parse_platform(bot: Bot, event: Event, state: T_State):
    """Interpret the user's reply to the platform prompt.

    Does nothing when ``state`` already holds a ``'platform'`` entry.
    The reply ``全部`` rejects with the list of every platform in
    ``platform_manager``; a known platform name is stored in
    ``state['platform']``; anything else rejects with an error message.
    """
    if 'platform' in state:
        return
    answer = str(event.get_message()).strip()
    if answer == '全部':
        # Header on its own line, then one "platform_name: display name"
        # entry per line (the header used to be glued to the first entry).
        message = '全部平台\n' + '\n'.join(
            '{}: {}'.format(platform_name, platform.name)
            for platform_name, platform in platform_manager.items()
        )
        await add_sub.reject(message)
    elif answer in platform_manager:
        state['platform'] = answer
    else:
        await add_sub.reject('平台输入错误')
# @add_sub.got('platform', '请输入想要订阅的平台,目前支持:{}'.format(', '.join(platform_manager.keys())))
# @add_sub.got('id', '请输入订阅用户的id,详情查阅https://github.com/felinae98/nonebot-hk-reporter')
@add_sub.handle()
async def add_sub_handle_id(bot: Bot, event: Event, state: T_State):
@@ -2,6 +2,7 @@ from .bilibili import Bilibili
from .rss import Rss
from .weibo import Weibo
from .wechat import Wechat
from .utils import check_sub_target
from .utils import check_sub_target, fetch_and_send
from .platform import PlatformNoTarget
from .utils import platform_manager
@@ -19,7 +19,11 @@ class Arknights(PlatformNoTarget):
categories = {}
platform_name = 'arknights'
name = '明日方舟游戏内公告'
enable_tag = False
enabled = True
is_common = False
schedule_interval = 30
@staticmethod
async def get_account_name(_: Target) -> str:
@@ -19,6 +19,10 @@ class Bilibili(Platform):
}
platform_name = 'bilibili'
enable_tag = False
enabled = True
is_common = True
schedule_interval = 10
name = 'B站'
@staticmethod
async def get_account_name(target: Target) -> Optional[str]:
@@ -12,6 +12,10 @@ class MonsterSiren(PlatformNoTarget):
categories = {}
platform_name = 'monster-siren'
enable_tag = False
enabled = True
is_common = False
schedule_interval = 30
name = '塞壬唱片官网新闻'
@staticmethod
async def get_account_name(_) -> str:
@@ -6,182 +6,55 @@ import httpx
from nonebot import logger
from ..config import Config
from ..plugin_config import plugin_config
from ..post import Post
from ..types import Category, RawPost, Tag, Target, User
from ..utils import Singleton
class CategoryNotSupport(Exception):
    """Raised in get_category when the category of a post is not supported."""
class PlatformProto(metaclass=Singleton):
class RegistryMeta(type):
    """Metaclass that records concrete platform classes in ``registory``.

    The first class created with this metaclass receives an empty
    ``registory`` list; every later class that can already see that list
    is appended to it, except classes named ``Platform`` or
    ``PlatformNoTarget``. Classes other than the three framework bases
    must define ``platform_name`` in their own body.
    """

    def __new__(cls, name, bases, namespace, **kwargs):
        # Only the framework base classes may omit `platform_name`.
        is_framework_base = name in ('PlatformProto', 'Platform', 'PlatformNoTarget')
        if not is_framework_base and 'platform_name' not in namespace:
            raise TypeError('Platform has no `platform_name`')
        return super().__new__(cls, name, bases, namespace, **kwargs)

    def __init__(cls, name, bases, namespace, **kwargs):
        if hasattr(cls, 'registory'):
            # A subclass: register it unless it is an intermediate base.
            if name not in ('Platform', 'PlatformNoTarget'):
                cls.registory.append(cls)
        else:
            # The root class: create the shared registry list.
            cls.registory = []
        super().__init__(name, bases, namespace, **kwargs)
class PlatformProto(metaclass=RegistryMeta):
    """Interface shared by the platform classes.

    Declares the class-level attributes and two entry points; the method
    bodies here are placeholders (``...``).
    """
    # category -> display name of that category
    categories: dict[Category, str]
    # display name -> category (inverse mapping of `categories`)
    reverse_category: dict[str, Category]
    # whether the platform is addressed through a target (account)
    has_target: bool
    # identifier of the platform; used as its key in platform_manager
    platform_name: str
    # human-readable name shown to users when platforms are listed
    name: str
    # whether tag filtering is applied for this platform
    enable_tag: bool
    # post id -> parsed Post
    cache: dict[Any, Post]
    # only enabled platform classes are instantiated into platform_manager
    enabled: bool
    # whether the platform is listed in the default (common) platform prompt
    is_common: bool
    # number of seconds between two scheduled fetch jobs
    schedule_interval: int
    async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]:
        """Return (user, posts) pairs for the given target and users."""
        ...
    @staticmethod
    async def get_account_name(target: Target) -> Optional[str]:
        """Return the account name of the target, or None."""
        ...
class Platform(PlatformProto):
    """Platform with target(account), like weibo, bilibili.

    Subclasses supply the ``get_*`` / ``parse`` hooks. This class keeps
    track of the post ids already seen per target and applies each user's
    category and tag subscription filters.
    """

    categories: dict[Category, str]
    has_target: bool = True
    platform_name: str
    enable_tag: bool

    def __init__(self):
        # target -> ids of the posts already seen for that target
        self.exists_posts = defaultdict(set)
        # target -> True once the first fetch has seeded `exists_posts`
        self.inited = dict()
        # display name -> category (inverse of `categories`)
        self.reverse_category = {}
        # post id -> parsed Post; emptied at the end of a fetch round
        self.cache: dict[Any, Post] = {}
        for key, val in self.categories.items():
            self.reverse_category[val] = key

    @staticmethod
    async def get_account_name(target: Target) -> Optional[str]:
        """Given a target, return the username(name) of the target."""
        raise NotImplementedError()

    async def get_sub_list(self, target: Target) -> list[RawPost]:
        """Get post list of the given target."""
        raise NotImplementedError()

    def get_id(self, post: RawPost) -> Any:
        """Get post id of given RawPost."""
        raise NotImplementedError()

    def get_date(self, post: RawPost) -> Optional[int]:
        """Get post timestamp and return, return None if can't get the time."""
        raise NotImplementedError()

    def get_category(self, post: RawPost) -> Optional[Category]:
        """Return category of given RawPost.

        May raise CategoryNotSupport; such posts are dropped by
        ``filter_common``.
        """
        raise NotImplementedError()

    def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]:
        """Return Tag list of given RawPost."""
        raise NotImplementedError()

    async def parse(self, raw_post: RawPost) -> Post:
        """Parse RawPost into Post."""
        raise NotImplementedError()

    def filter_platform_custom(self, post: RawPost) -> bool:
        """Optional platform-specific filter; a falsy result drops the post.

        Leaving it unimplemented (NotImplementedError) means no custom
        filtering is applied.
        """
        raise NotImplementedError()

    async def _parse_with_cache(self, post: RawPost) -> Post:
        """Parse a raw post at most once per fetch round, keyed by post id."""
        post_id = self.get_id(post)
        if post_id not in self.cache:
            self.cache[post_id] = await self.parse(post)
        return self.cache[post_id]

    async def filter_common(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]:
        """Return the raw posts of ``target`` that have not been seen yet.

        The first call for a target only records the ids of the current
        posts and returns an empty list. Later calls skip posts that are
        already recorded, older than two hours, rejected by
        ``filter_platform_custom`` or of an unsupported category.
        """
        if not self.inited.get(target, False):
            # target not init
            for raw_post in raw_post_list:
                post_id = self.get_id(raw_post)
                self.exists_posts[target].add(post_id)
            logger.info('init {}-{} with {}'.format(self.platform_name, target, self.exists_posts[target]))
            self.inited[target] = True
            return []
        res: list[RawPost] = []
        for raw_post in raw_post_list:
            post_id = self.get_id(raw_post)
            if post_id in self.exists_posts[target]:
                continue
            # Posts older than two hours are ignored.
            if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60:
                continue
            try:
                if not self.filter_platform_custom(raw_post):
                    continue
            except NotImplementedError:
                # No platform-specific filter defined.
                pass
            try:
                self.get_category(raw_post)
            except CategoryNotSupport:
                continue
            except NotImplementedError:
                # Platform has no categories.
                pass
            res.append(raw_post)
            self.exists_posts[target].add(post_id)
        return res

    async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]:
        """Keep the raw posts matching a user's category and tag choices.

        An empty ``cats`` or ``tags`` list means "no restriction". A post
        whose ``get_tags`` result is None is treated as having no tags.
        """
        res: list[RawPost] = []
        for raw_post in raw_post_list:
            if self.categories:
                cat = self.get_category(raw_post)
                if cats and cat not in cats:
                    continue
            if self.enable_tag and tags:
                # get_tags is Optional: fall back to an empty list instead
                # of failing on iteration over None.
                post_tags = self.get_tags(raw_post) or []
                if not any(tag in tags for tag in post_tags):
                    continue
            res.append(raw_post)
        return res

    async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]:
        """Fetch the new posts of ``target`` and pair them with each user.

        Returns one ``(user, posts)`` tuple per user, or an empty list when
        there is nothing new or an ``httpx.RequestError`` occurred.
        """
        try:
            config = Config()
            post_list = await self.get_sub_list(target)
            new_posts = await self.filter_common(target, post_list)
            res: list[tuple[User, list[Post]]] = []
            if not new_posts:
                return []
            for post in new_posts:
                logger.info('fetch new post from {} {}: {}'.format(self.platform_name, target, self.get_id(post)))
            for user in users:
                required_tags = config.get_sub_tags(self.platform_name, target, user.user_type, user.user) if self.enable_tag else []
                cats = config.get_sub_category(self.platform_name, target, user.user_type, user.user)
                user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
                user_post: list[Post] = []
                for raw_post in user_raw_post:
                    user_post.append(await self._parse_with_cache(raw_post))
                res.append((user, user_post))
            # Parsed posts are only shared between the users of one round.
            self.cache = {}
            return res
        except httpx.RequestError as err:
            logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url))
            return []
class PlatformNoTarget(PlatformProto):
categories: dict[Category, str]
has_target = False
platform_name: str
enable_tag: bool
def __init__(self):
self.exists_posts = set()
self.inited = False
self.reverse_category = {}
self.cache: dict[Any, Post] = {}
for key, val in self.categories.items():
self.reverse_category[val] = key
@staticmethod
async def get_account_name(target: Target) -> Optional[str]:
"return the username(name) of the target"
raise NotImplementedError()
async def get_sub_list(self) -> list[RawPost]:
"Get post list of the given target"
raise NotImplementedError()
def get_id(self, post: RawPost) -> Any:
"Get post id of given RawPost"
raise NotImplementedError()
@@ -219,21 +92,14 @@ class PlatformNoTarget(PlatformProto):
retry_times -= 1
return self.cache[post_id]
async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
if not self.inited:
# target not init
for raw_post in raw_post_list:
post_id = self.get_id(raw_post)
self.exists_posts.add(post_id)
logger.info('init {} with {}'.format(self.platform_name, self.exists_posts))
self.inited = True
return []
res: list[RawPost] = []
def _do_filter_common(self, raw_post_list: list[RawPost], exists_posts_set: set) -> list[RawPost]:
res = []
for raw_post in raw_post_list:
post_id = self.get_id(raw_post)
if post_id in self.exists_posts:
if post_id in exists_posts_set:
continue
if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60:
if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60 and \
plugin_config.hk_reporter_init_filter:
continue
try:
if not self.filter_platform_custom(raw_post):
@@ -247,7 +113,7 @@ class PlatformNoTarget(PlatformProto):
except NotImplementedError:
pass
res.append(raw_post)
self.exists_posts.add(post_id)
exists_posts_set.add(post_id)
return res
async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]:
@@ -269,6 +135,94 @@ class PlatformNoTarget(PlatformProto):
res.append(raw_post)
return res
class Platform(PlatformProto):
    """Platform with target(account), like weibo, bilibili."""

    categories: dict[Category, str]
    has_target: bool = True
    platform_name: str
    enable_tag: bool

    def __init__(self):
        # Per-target bookkeeping: seen post ids and "first fetch done" flags.
        self.exists_posts = defaultdict(set)
        self.inited = dict()
        self.cache: dict[Any, Post] = {}
        # Inverse of `categories`: display name -> category.
        self.reverse_category = {
            display: category for category, display in self.categories.items()
        }

    async def get_sub_list(self, target: Target) -> list[RawPost]:
        """Get post list of the given target."""
        raise NotImplementedError()

    async def filter_common(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]:
        """Return the not-yet-seen raw posts of ``target``.

        With ``hk_reporter_init_filter`` enabled, the first call for a
        target only records the current post ids and returns ``[]``.
        """
        seen_ids = self.exists_posts[target]
        first_round = not self.inited.get(target, False)
        if first_round and plugin_config.hk_reporter_init_filter:
            # target not init
            seen_ids.update(self.get_id(raw_post) for raw_post in raw_post_list)
            logger.info('init {}-{} with {}'.format(self.platform_name, target, seen_ids))
            self.inited[target] = True
            return []
        return self._do_filter_common(raw_post_list, seen_ids)

    async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]:
        """Fetch new posts of ``target`` and build one (user, posts) pair per user.

        Returns ``[]`` when nothing is new or on ``httpx.RequestError``.
        """
        try:
            config = Config()
            fetched = await self.get_sub_list(target)
            fresh_posts = await self.filter_common(target, fetched)
            if not fresh_posts:
                return []
            for fresh in fresh_posts:
                logger.info('fetch new post from {} {}: {}'.format(self.platform_name, target, self.get_id(fresh)))
            deliveries: list[tuple[User, list[Post]]] = []
            for user in users:
                if self.enable_tag:
                    required_tags = config.get_sub_tags(self.platform_name, target, user.user_type, user.user)
                else:
                    required_tags = []
                cats = config.get_sub_category(self.platform_name, target, user.user_type, user.user)
                matched = await self.filter_user_custom(fresh_posts, cats, required_tags)
                parsed = [await self._parse_with_cache(raw_post) for raw_post in matched]
                deliveries.append((user, parsed))
            self.cache = {}
            return deliveries
        except httpx.RequestError as err:
            logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url))
            return []
class PlatformNoTarget(PlatformProto):
categories: dict[Category, str]
has_target = False
platform_name: str
enable_tag: bool
async def get_sub_list(self) -> list[RawPost]:
"Get post list of the given target"
raise NotImplementedError()
def __init__(self):
self.exists_posts = set()
self.inited = False
self.reverse_category = {}
self.cache: dict[Any, Post] = {}
for key, val in self.categories.items():
self.reverse_category[val] = key
async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
if not self.inited and plugin_config.hk_reporter_init_filter:
# target not init
for raw_post in raw_post_list:
post_id = self.get_id(raw_post)
self.exists_posts.add(post_id)
logger.info('init {} with {}'.format(self.platform_name, self.exists_posts))
self.inited = True
return []
return self._do_filter_common(raw_post_list, self.exists_posts)
async def fetch_new_post(self, _: Target, users: list[User]) -> list[tuple[User, list[Post]]]:
try:
config = Config()
@@ -14,6 +14,10 @@ class Rss(Platform):
categories = {}
enable_tag = False
platform_name = 'rss'
name = "Rss"
enabled = True
is_common = True
schedule_interval = 30
@staticmethod
async def get_account_name(target: Target) -> Optional[str]:
@@ -2,12 +2,6 @@ import nonebot
from nonebot import logger
from collections import defaultdict
from typing import Type
from .arknights import Arknights
from .weibo import Weibo
from .bilibili import Bilibili
from .monster_siren import MonsterSiren
from .rss import Rss
from .wechat import Wechat
from .platform import PlatformProto
from ..config import Config
from ..post import Post
@@ -17,12 +11,8 @@ async def check_sub_target(target_type, target):
return await platform_manager[target_type].get_account_name(target)
platform_manager: dict[str, PlatformProto] = {
'bilibili': Bilibili(),
'weibo': Weibo(),
'rss': Rss(),
'arknights': Arknights(),
'monster-siren': MonsterSiren(),
# 'wechat': Wechat(),
obj.platform_name: obj() for obj in \
filter(lambda platform: platform.enabled, PlatformProto.registory)
}
async def fetch_and_send(target_type: str):
@@ -17,6 +17,9 @@ class Wechat(Platform):
categories = {}
enable_tag = False
platform_name = 'wechat'
enabled = False
is_common = False
name = '微信公众号'
@classmethod
def _get_query_url(cls, target: Target):
@@ -21,6 +21,10 @@ class Weibo(Platform):
}
enable_tag = False
platform_name = 'weibo'
name = '新浪微博'
enabled = True
is_common = True
schedule_interval = 10
def __init__(self):
self.top : dict[Target, RawPost] = dict()
@@ -6,6 +6,7 @@ class PlugConfig(BaseSettings):
hk_reporter_config_path: str = ""
hk_reporter_use_pic: bool = False
hk_reporter_use_local: bool = False
hk_reporter_init_filter: bool = True
class Config:
extra = 'ignore'
+14 -26
View File
@@ -1,35 +1,23 @@
from nonebot import require
from nonebot import get_driver, logger
from .send import do_send_msgs
from .platform.utils import fetch_and_send
from .platform import fetch_and_send, platform_manager
from apscheduler.schedulers.asyncio import AsyncIOScheduler
scheduler: AsyncIOScheduler = require('nonebot_plugin_apscheduler').scheduler
scheduler = AsyncIOScheduler()
@scheduler.scheduled_job('interval', seconds=10)
async def weibo_check():
    """Run fetch_and_send for the 'weibo' platform; scheduled every 10 seconds."""
    await fetch_and_send('weibo')
async def _start():
    """Set the scheduler timezone to Asia/Shanghai and start the scheduler."""
    scheduler.configure({"apscheduler.timezone": "Asia/Shanghai"})
    scheduler.start()
@scheduler.scheduled_job('interval', seconds=10)
async def bilibili_check():
    """Run fetch_and_send for the 'bilibili' platform; scheduled every 10 seconds."""
    await fetch_and_send('bilibili')
get_driver().on_startup(_start)
@scheduler.scheduled_job('interval', seconds=30)
async def rss_check():
    """Run fetch_and_send for the 'rss' platform; scheduled every 30 seconds."""
    await fetch_and_send('rss')
@scheduler.scheduled_job('interval', seconds=30)
async def arknights_check():
    """Run fetch_and_send for the 'arknights' platform; scheduled every 30 seconds."""
    await fetch_and_send('arknights')
@scheduler.scheduled_job('interval', seconds=30)
async def monster_siren_check():
    """Run fetch_and_send for the 'monster-siren' platform; scheduled every 30 seconds."""
    await fetch_and_send('monster-siren')
# @scheduler.scheduled_job('interval', seconds=30)
# async def wechat_check():
# await fetch_and_send('wechat')
# Register one interval job per platform that declares an integer
# `schedule_interval` (in seconds).
for platform_name, platform in platform_manager.items():
    # A platform class may not define `schedule_interval` at all; treat that
    # like a non-integer value and skip it instead of raising AttributeError.
    interval = getattr(platform, 'schedule_interval', None)
    if isinstance(interval, int):
        logger.info(f'start scheduler for {platform_name} with interval {interval}')
        scheduler.add_job(
            fetch_and_send, 'interval', seconds=interval,
            args=(platform_name,))
@scheduler.scheduled_job('interval', seconds=1)
async def _():
async def _send_msgs():
await do_send_msgs()