diff --git a/.gitignore b/.gitignore index 256d649..e2bad75 100644 --- a/.gitignore +++ b/.gitignore @@ -180,3 +180,4 @@ tags # End of https://www.toptal.com/developers/gitignore/api/python,linux,vim data/* .env.* +.vim/* diff --git a/README.md b/README.md index 750f8ba..c3a3965 100644 --- a/README.md +++ b/README.md @@ -33,9 +33,9 @@ ### 命令 所有命令都需要@bot触发 -* 添加订阅(仅管理员和群主):`添加订阅 平台代码 uid` +* 添加订阅(仅管理员和群主):`添加订阅` * 查询订阅:`查询订阅` -* 删除订阅(仅管理员和群主):`删除订阅 平台代码 uid` +* 删除订阅(仅管理员和群主):`删除订阅` 平台代码包含:weibo,bilibili,rss
diff --git a/pyproject.toml b/pyproject.toml index f1ee4f6..462ce0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "nonebot-hk-reporter" -version = "0.1.1" +version = "0.2.1" description = "Subscribe message from social medias" authors = ["felinae98 "] license = "MIT" diff --git a/src/plugins/hk_reporter/__init__.py b/src/plugins/hk_reporter/__init__.py index 23b7c12..210817a 100644 --- a/src/plugins/hk_reporter/__init__.py +++ b/src/plugins/hk_reporter/__init__.py @@ -1,7 +1,4 @@ import nonebot -from .plugin_config import PlugConfig -global_config = nonebot.get_driver().config -plugin_config = PlugConfig(**global_config.dict()) from . import config_manager from . import config diff --git a/src/plugins/hk_reporter/config.py b/src/plugins/hk_reporter/config.py index a935e1b..84ea072 100644 --- a/src/plugins/hk_reporter/config.py +++ b/src/plugins/hk_reporter/config.py @@ -1,10 +1,14 @@ -from .utils import Singleton, supported_target_type -from . import plugin_config -from os import path -import nonebot -from tinydb import TinyDB, Query from collections import defaultdict +from os import path import os +from typing import DefaultDict + +import nonebot +from tinydb import Query, TinyDB + +from .plugin_config import plugin_config +from .types import User +from .utils import Singleton, supported_target_type def get_config_path() -> str: @@ -25,30 +29,35 @@ class NoSuchSubscribeException(Exception): class Config(metaclass=Singleton): - migrate_version = 1 + migrate_version = 2 def __init__(self): self.db = TinyDB(get_config_path(), encoding='utf-8') self.kv_config = self.db.table('kv') self.user_target = self.db.table('user_target') self.target_user_cache = {} + self.target_user_cat_cache = {} + self.target_user_tag_cache = {} self.target_list = {} - self.next_index = defaultdict(lambda: 0) + self.next_index: DefaultDict[str, int] = defaultdict(lambda: 0) - def add_subscribe(self, user, user_type, target, target_name, target_type): + def add_subscribe(self, user, user_type, target, target_name, target_type, cats, tags): user_query = Query() query = (user_query.user == user) & (user_query.user_type == user_type) if (user_data := self.user_target.get(query)): # update subs: list = user_data.get('subs', []) - subs.append({"target": target, "target_type": target_type, 'target_name': target_name}) + subs.append({"target": target, "target_type": target_type, 'target_name': target_name, 'cats': cats, 'tags': tags}) self.user_target.update({"subs": subs}, query) else: # insert - self.user_target.insert({'user': user, 'user_type': user_type, 'subs': [{'target': target, 'target_type': target_type, 'target_name': target_name}]}) + self.user_target.insert({ + 'user': user, 'user_type': user_type, + 'subs': [{'target': target, 'target_type': target_type, 'target_name': target_name, 'cats': cats, 'tags': tags }] + }) self.update_send_cache() - def list_subscribe(self, user, user_type): + def list_subscribe(self, user, user_type): query = Query() return self.user_target.get((query.user == user) & (query.user_type ==user_type))['subs'] @@ -68,16 +77,28 @@ class Config(metaclass=Singleton): def update_send_cache(self): res = {target_type: defaultdict(list) for target_type in supported_target_type} + cat_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} + tag_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} # res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} for user in self.user_target.all(): for sub in user.get('subs', []): if not sub.get('target_type') in supported_target_type: continue - res[sub['target_type']][sub['target']].append({"user": user['user'], "user_type": user['user_type']}) + res[sub['target_type']][sub['target']].append(User(user['user'], user['user_type'])) + cat_res[sub['target_type']][sub['target']]['{}-{}'.format(user['user_type'], user['user'])] = sub['cats'] + tag_res[sub['target_type']][sub['target']]['{}-{}'.format(user['user_type'], user['user'])] = sub['tags'] self.target_user_cache = res + self.target_user_cat_cache = cat_res + self.target_user_tag_cache = tag_res for target_type in self.target_user_cache: self.target_list[target_type] = list(self.target_user_cache[target_type].keys()) + def get_sub_category(self, target_type, target, user_type, user): + return self.target_user_cat_cache[target_type][target]['{}-{}'.format(user_type, user)] + + def get_sub_tags(self, target_type, target, user_type, user): + return self.target_user_tag_cache[target_type][target]['{}-{}'.format(user_type, user)] + def get_next_target(self, target_type): # FIXME 插入或删除target后对队列的影响(但是并不是大问题 if not self.target_list[target_type]: @@ -92,7 +113,19 @@ def start_up(): if not (search_res := config.kv_config.search(Query().name=="version")): config.kv_config.insert({"name": "version", "value": config.migrate_version}) elif search_res[0].get("value") < config.migrate_version: - pass + query = Query() + version_query = (query.name == 'version') + cur_version = search_res[0].get("value") + if cur_version == 1: + cur_version = 2 + for user_conf in config.user_target.all(): + conf_id = user_conf.doc_id + subs = user_conf['subs'] + for sub in subs: + sub['cats'] = [] + sub['tags'] = [] + config.user_target.update({'subs': subs}, doc_ids=[conf_id]) + config.kv_config.update({"value": config.migrate_version}, version_query) # do migration config.update_send_cache() diff --git a/src/plugins/hk_reporter/config_manager.py b/src/plugins/hk_reporter/config_manager.py index 3fcbc7b..1375627 100644 --- a/src/plugins/hk_reporter/config_manager.py +++ b/src/plugins/hk_reporter/config_manager.py @@ -1,53 +1,141 @@ +from nonebot import logger, on_command +from nonebot.adapters.cqhttp import Bot, Event, GroupMessageEvent +from nonebot.adapters.cqhttp.message import Message +from nonebot.adapters.cqhttp.permission import GROUP_ADMIN, GROUP_MEMBER, GROUP_OWNER +from nonebot.permission import Permission, SUPERUSER from nonebot.rule import to_me from nonebot.typing import T_State -from nonebot.adapters.cqhttp import Bot, Event, GroupMessageEvent -from nonebot.permission import Permission -from nonebot.adapters.cqhttp.permission import GROUP_ADMIN, GROUP_MEMBER, GROUP_OWNER -from nonebot import on_command -from .platform.utils import check_sub_target from .config import Config, NoSuchSubscribeException -from .utils import parse_text +from .platform import platform_manager +from .platform.utils import check_sub_target from .send import send_msgs +from .utils import parse_text -add_sub = on_command("添加订阅", rule=to_me(), permission=GROUP_ADMIN | GROUP_OWNER, priority=5) +help_match = on_command('help', rule=to_me(), priority=5) +@help_match.handle() +async def send_help(bot: Bot, event: Event, state: T_State): + message = '使用方法:\n@bot 添加订阅(仅管理员)\n@bot 查询订阅\n@bot 删除订阅(仅管理员)' + await help_match.finish(Message(await parse_text(message))) + +add_sub = on_command("添加订阅", rule=to_me(), permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, priority=5) +@add_sub.got('platform', '请输入想要订阅的平台,目前支持:{}'.format(', '.join(platform_manager.keys()))) +# @add_sub.got('id', '请输入订阅用户的id,详情查阅https://github.com/felinae98/nonebot-hk-reporter') @add_sub.handle() -async def _(bot: Bot, event: Event, state: T_State): - args = str(event.get_message()).strip().split() - if len(args) != 2: - await add_sub.finish("使用方法为: 添加订阅 平台 id") +async def add_sub_handle_id(bot: Bot, event: Event, state: T_State): + if not platform_manager[state['platform']].has_target or 'id' in state: return - target_type, target = args - if name := await check_sub_target(target_type, target): - config: Config = Config() - config.add_subscribe(event.group_id, "group", target, name, target_type) - await add_sub.finish("成功添加 {}".format(name)) + await bot.send(event=event, message='请输入订阅用户的id,详情查阅https://github.com/felinae98/nonebot-hk-reporter') + await add_sub.pause() + +@add_sub.handle() +async def add_sub_parse_id(bot: Bot, event: Event, state: T_State): + if not platform_manager[state['platform']].has_target or 'id' in state: + return + target = str(event.get_message()).strip() + name = await check_sub_target(state['platform'], target) + if not name: + await add_sub.reject('id输入错误') + state['id'] = target + state['name'] = name + +@add_sub.handle() +async def add_sub_handle_cat(bot: Bot, event: Event, state: T_State): + if not platform_manager[state['platform']].categories: + return + if 'cats' in state: + return + await bot.send(event=event, message='请输入要订阅的类别,以空格分隔,支持的类别有:{}'.format( + ','.join(list(platform_manager[state['platform']].categories.values())) + )) + await add_sub.pause() + +@add_sub.handle() +async def add_sub_parse_cat(bot: Bot, event: Event, state: T_State): + if not platform_manager[state['platform']].categories: + return + if 'cats' in state: + return + res = [] + for cat in str(event.get_message()).strip().split(): + if cat not in platform_manager[state['platform']].reverse_category: + await add_sub.reject('不支持 {}'.format(cat)) + res.append(platform_manager[state['platform']].reverse_category[cat]) + state['cats'] = res + +@add_sub.handle() +async def add_sub_handle_tag(bot: Bot, event: Event, state: T_State): + if not platform_manager[state['platform']].enable_tag: + return + if 'tags' in state: + return + await bot.send(event=event, message='请输入要订阅的tag,订阅所有tag输入"全部标签"') + await add_sub.pause() + +@add_sub.handle() +async def add_sub_parse_tag(bot: Bot, event: Event, state: T_State): + if not platform_manager[state['platform']].enable_tag: + return + if 'tags' in state: + return + if str(event.get_message()).strip() == '全部标签': + state['tags'] = [] else: - await add_sub.finish("平台或者id不存在") - + state['tags'] = str(event.get_message()).strip().split() + +@add_sub.handle() +async def add_sub_process(bot: Bot, event: GroupMessageEvent, state: T_State): + config = Config() + config.add_subscribe(event.group_id, user_type='group', + target=state['id'] if platform_manager[state['platform']].has_target else 'default', + target_name=state['name'], target_type=state['platform'], + cats=state.get('cats', []), tags=state.get('tags', [])) + await add_sub.finish('添加 {} 成功'.format(state['name'])) + query_sub = on_command("查询订阅", rule=to_me(), priority=5) @query_sub.handle() -async def _(bot: Bot, event: Event, state: T_State): +async def _(bot: Bot, event: GroupMessageEvent, state: T_State): config: Config = Config() sub_list = config.list_subscribe(event.group_id, "group") res = '订阅的帐号为:\n' for sub in sub_list: - res += '{} {} {}\n'.format(sub['target_type'], sub['target_name'], sub['target']) - send_msgs(bot, event.group_id, 'group', [await parse_text(res)]) - await query_sub.finish() + res += '{} {} {}'.format(sub['target_type'], sub['target_name'], sub['target']) + platform = platform_manager[sub['target_type']] + if platform.categories: + res += ' [{}]'.format(', '.join(map(lambda x: platform.categories[x], sub['cats']))) + if platform.enable_tag: + res += ' {}'.format(', '.join(sub['tags'])) + res += '\n' + # send_msgs(bot, event.group_id, 'group', [await parse_text(res)]) + await query_sub.finish(Message(await parse_text(res))) del_sub = on_command("删除订阅", rule=to_me(), permission=GROUP_ADMIN | GROUP_OWNER, priority=5) @del_sub.handle() -async def _(bot: Bot, event: Event, state: T_State): - args = str(event.get_message()).strip().split() - if len(args) != 2: - await del_sub.finish("使用方法为: 删除订阅 平台 id") - return - target_type, target = args - config = Config() - try: - config.del_subscribe(event.group_id, "group", target, target_type) - except NoSuchSubscribeException: - await del_sub.finish('平台或id不存在') - await del_sub.finish('删除成功') +async def send_list(bot: Bot, event: GroupMessageEvent, state: T_State): + config: Config = Config() + sub_list = config.list_subscribe(event.group_id, "group") + res = '订阅的帐号为:\n' + state['sub_table'] = {} + for index, sub in enumerate(sub_list, 1): + state['sub_table'][index] = {'target_type': sub['target_type'], 'target': sub['target']} + res += '{} {} {} {}\n'.format(index, sub['target_type'], sub['target_name'], sub['target']) + platform = platform_manager[sub['target_type']] + if platform.categories: + res += ' [{}]'.format(', '.join(map(lambda x: platform.categories[x], sub['cats']))) + if platform.enable_tag: + res += ' {}'.format(', '.join(sub['tags'])) + res += '\n' + res += '请输入要删除的订阅的序号' + await bot.send(event=event, message=Message(await parse_text(res))) +@del_sub.receive() +async def do_del(bot, event: GroupMessageEvent, state: T_State): + try: + index = int(str(event.get_message()).strip()) + config = Config() + config.del_subscribe(event.group_id, 'group', **state['sub_table'][index]) + except Exception as e: + await del_sub.reject('删除错误') + logger.warning(e) + else: + await del_sub.finish('删除成功') diff --git a/src/plugins/hk_reporter/platform/__init__.py b/src/plugins/hk_reporter/platform/__init__.py index 4c591db..ea37e30 100644 --- a/src/plugins/hk_reporter/platform/__init__.py +++ b/src/plugins/hk_reporter/platform/__init__.py @@ -2,3 +2,5 @@ from .bilibili import Bilibili from .rss import Rss from .weibo import Weibo from .utils import check_sub_target +from .platform import PlatformNoTarget +from .utils import platform_manager diff --git a/src/plugins/hk_reporter/platform/bilibili.py b/src/plugins/hk_reporter/platform/bilibili.py index 9e1f419..ba5bdba 100644 --- a/src/plugins/hk_reporter/platform/bilibili.py +++ b/src/plugins/hk_reporter/platform/bilibili.py @@ -1,91 +1,92 @@ -from ..utils import Singleton -from ..post import Post from collections import defaultdict -from nonebot import logger -import httpx import json -import time +from typing import Any, Optional -class Bilibili(metaclass=Singleton): - - def __init__(self): - self.exists_posts = defaultdict(set) - self.inited = defaultdict(lambda: False) +import httpx - async def get_user_post_list(self, user_id): +from ..post import Post +from ..types import Category, RawPost, Tag, Target +from .platform import CategoryNotSupport, Platform + +class Bilibili(Platform): + + categories = { + 1: "一般动态", + 2: "专栏文章", + 3: "视频", + 4: "纯文字", + # 5: "短视频" + } + platform_name = 'bilibili' + enable_tag = False + + @staticmethod + async def get_account_name(target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: - params = {'host_uid': user_id, 'offset': 0, 'need_top': 0} + res = await client.get('https://api.bilibili.com/x/space/acc/info', params={'mid': target}) + res_data = json.loads(res.text) + if res_data['code']: + return None + return res_data['data']['name'] + + async def get_sub_list(self, target: Target) -> list[RawPost]: + async with httpx.AsyncClient() as client: + params = {'host_uid': target, 'offset': 0, 'need_top': 0} res = await client.get('https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history', params=params, timeout=4.0) res_dict = json.loads(res.text) if res_dict['code'] == 0: - return res_dict['data'] + return res_dict['data']['cards'] + else: + return [] - def filter(self, data, target, init=False) -> list[Post]: - cards = data['cards'] - res: list[Post] = [] - for card in cards: - dynamic_id = card['desc']['dynamic_id'] - if init: - self.exists_posts[target].add(dynamic_id) - continue - if dynamic_id in self.exists_posts[target]: - continue - if time.time() - card['desc']['timestamp'] > 60 * 60 * 2: - continue - res.append(self.parse(card, target)) - if None in res: - res.remove(None) - return res + def get_id(self, post: RawPost) -> Any: + return post['desc']['dynamic_id'] + + def get_date(self, post: RawPost) -> int: + return post['desc']['timestamp'] + def get_category(self, post: RawPost) -> Category: + post_type = post['desc']['type'] + if post_type == 2: + return Category(1) + elif post_type == 64: + return Category(2) + elif post_type == 8: + return Category(3) + elif post_type == 4: + return Category(4) + elif post_type == 1: + # 转发 + raise CategoryNotSupport() + raise CategoryNotSupport() - def parse(self, card, target) -> Post: - card_content = json.loads(card['card']) - dynamic_id = card['desc']['dynamic_id'] - self.exists_posts[target].add(dynamic_id) - if card['desc']['type'] == 2: + def get_tags(self, raw_post: RawPost) -> list[Tag]: + return [] + + async def parse(self, raw_post: RawPost) -> Post: + card_content = json.loads(raw_post['card']) + post_type = self.get_category(raw_post) + if post_type == 1: # 一般动态 text = card_content['item']['description'] - url = 'https://t.bilibili.com/{}'.format(card['desc']['dynamic_id']) + url = 'https://t.bilibili.com/{}'.format(raw_post['desc']['dynamic_id']) pic = [img['img_src'] for img in card_content['item']['pictures']] - elif card['desc']['type'] == 64: + elif post_type == 2: # 专栏文章 text = '{} {}'.format(card_content['title'], card_content['summary']) - url = 'https://www.bilibili.com/read/cv{}'.format(card['desc']['rid']) + url = 'https://www.bilibili.com/read/cv{}'.format(raw_post['desc']['rid']) pic = card_content['image_urls'] - elif card['desc']['type'] == 8: + elif post_type == 3: # 视频 text = card_content['dynamic'] - url = 'https://www.bilibili.com/video/{}'.format(card['desc']['bvid']) + url = 'https://www.bilibili.com/video/{}'.format(raw_post['desc']['bvid']) pic = [card_content['pic']] - elif card['desc']['type'] == 4: + elif post_type == 4: # 纯文字 text = card_content['item']['content'] - url = 'https://t.bilibili.com/{}'.format(card['desc']['dynamic_id']) + url = 'https://t.bilibili.com/{}'.format(raw_post['desc']['dynamic_id']) pic = [] else: - logger.error(card) - return None + raise CategoryNotSupport(post_type) return Post('bilibili', text, url, pic) - async def fetch_new_post(self, target) -> list[Post]: - try: - post_list_data = await self.get_user_post_list(target) - if self.inited[target]: - return self.filter(post_list_data, target) - else: - self.filter(post_list_data, target, True) - logger.info('bilibili init {} success'.format(target)) - logger.info('post list: {}'.format(self.exists_posts[target])) - self.inited[target] = True - return [] - except httpx.RequestError as err: - logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) - return [] - -async def get_user_info(mid): - async with httpx.AsyncClient() as client: - res = await client.get('https://api.bilibili.com/x/space/acc/info', params={'mid': mid}) - res_data = json.loads(res.text) - if res_data['code']: - return None - return res_data['data']['name'] diff --git a/src/plugins/hk_reporter/platform/platform.py b/src/plugins/hk_reporter/platform/platform.py new file mode 100644 index 0000000..27b0cf0 --- /dev/null +++ b/src/plugins/hk_reporter/platform/platform.py @@ -0,0 +1,285 @@ +import time +from collections import defaultdict +from typing import Any, Optional + +import httpx +from nonebot import logger + +from ..config import Config +from ..post import Post +from ..types import Category, RawPost, Tag, Target, User +from ..utils import Singleton + + +class CategoryNotSupport(Exception): + "raise in get_category, when post category is not supported" + pass + +class PlatformProto(metaclass=Singleton): + + categories: dict[Category, str] + reverse_category: dict[str, Category] + has_target: bool + platform_name: str + enable_tag: bool + async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]: + ... + @staticmethod + async def get_account_name(target: Target) -> Optional[str]: + ... + +class Platform(PlatformProto): + "platform with target(account), like weibo, bilibili" + + categories: dict[Category, str] + has_target: bool = True + platform_name: str + enable_tag: bool + + def __init__(self): + self.exists_posts = defaultdict(set) + self.inited = dict() + self.reverse_category = {} + self.cache: dict[Any, Post] = {} + for key, val in self.categories.items(): + self.reverse_category[val] = key + + @staticmethod + async def get_account_name(target: Target) -> Optional[str]: + "Given a tareget, return the username(name) of the target" + raise NotImplementedError() + + async def get_sub_list(self, target: Target) -> list[RawPost]: + "Get post list of the given target" + raise NotImplementedError() + + def get_id(self, post: RawPost) -> Any: + "Get post id of given RawPost" + raise NotImplementedError() + + def get_date(self, post: RawPost) -> Optional[int]: + "Get post timestamp and return, return None if can't get the time" + raise NotImplementedError() + + def get_category(self, post: RawPost) -> Optional[Category]: + "Return category of given Rawpost" + raise NotImplementedError() + + def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]: + "Return Tag list of given RawPost" + raise NotImplementedError() + + async def parse(self, raw_post: RawPost) -> Post: + "parse RawPost into post" + raise NotImplementedError() + + def filter_platform_custom(self, post: RawPost) -> bool: + raise NotImplementedError() + + async def _parse_with_cache(self, post: RawPost) -> Post: + post_id = self.get_id(post) + if post_id not in self.cache: + self.cache[post_id] = await self.parse(post) + return self.cache[post_id] + + async def filter_common(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]: + if not self.inited.get(target, False): + # target not init + for raw_post in raw_post_list: + post_id = self.get_id(raw_post) + self.exists_posts[target].add(post_id) + logger.info('init {}-{} with {}'.format(self.platform_name, target, self.exists_posts[target])) + self.inited[target] = True + return [] + res: list[RawPost] = [] + for raw_post in raw_post_list: + post_id = self.get_id(raw_post) + if post_id in self.exists_posts[target]: + continue + if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60: + continue + try: + if not self.filter_platform_custom(raw_post): + continue + except NotImplementedError: + pass + try: + self.get_category(raw_post) + except CategoryNotSupport: + continue + except NotImplementedError: + pass + res.append(raw_post) + self.exists_posts[target].add(post_id) + return res + + async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]: + res: list[RawPost] = [] + for raw_post in raw_post_list: + if self.categories: + cat = self.get_category(raw_post) + if cats and cat not in cats: + continue + if self.enable_tag and tags: + flag = False + post_tags = self.get_tags(raw_post) + for tag in post_tags: + if tag in tags: + flag = True + break + if not flag: + continue + res.append(raw_post) + return res + + async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]: + try: + config = Config() + post_list = await self.get_sub_list(target) + new_posts = await self.filter_common(target, post_list) + res: list[tuple[User, list[Post]]] = [] + if not new_posts: + return [] + else: + for post in new_posts: + logger.info('fetch new post from {} {}: {}'.format(self.platform_name, target, self.get_id(post))) + for user in users: + required_tags = config.get_sub_tags(self.platform_name, target, user.user_type, user.user) if self.enable_tag else [] + cats = config.get_sub_category(self.platform_name, target, user.user_type, user.user) + user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags) + user_post: list[Post] = [] + for raw_post in user_raw_post: + user_post.append(await self._parse_with_cache(raw_post)) + res.append((user, user_post)) + self.cache = {} + return res + except httpx.RequestError as err: + logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) + return [] + + +class PlatformNoTarget(PlatformProto): + + categories: dict[Category, str] + has_target = False + platform_name: str + enable_tag: bool + + def __init__(self): + self.exists_posts = set() + self.inited = False + self.reverse_category = {} + self.cache: dict[Any, Post] = {} + for key, val in self.categories.items(): + self.reverse_category[val] = key + + @staticmethod + async def get_account_name() -> Optional[str]: + "return the username(name) of the target" + raise NotImplementedError() + + async def get_sub_list(self) -> list[RawPost]: + "Get post list of the given target" + raise NotImplementedError() + + def get_id(self, post: RawPost) -> Any: + "Get post id of given RawPost" + raise NotImplementedError() + + def get_date(self, post: RawPost) -> Optional[int]: + "Get post timestamp and return, return None if can't get the time" + raise NotImplementedError() + + def get_category(self, post: RawPost) -> Optional[Category]: + "Return category of given Rawpost" + raise NotImplementedError() + + def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]: + "Return Tag list of given RawPost" + raise NotImplementedError() + + async def parse(self, raw_post: RawPost) -> Post: + "parse RawPost into post" + raise NotImplementedError() + + def filter_platform_custom(self, post: RawPost) -> bool: + raise NotImplementedError() + + async def _parse_with_cache(self, post: RawPost) -> Post: + post_id = self.get_id(post) + if post_id not in self.cache: + self.cache[post_id] = await self.parse(post) + return self.cache[post_id] + + async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]: + if not self.inited: + # target not init + for raw_post in raw_post_list: + post_id = self.get_id(raw_post) + self.exists_posts.add(post_id) + logger.info('init {} with {}'.format(self.platform_name, self.exists_posts)) + self.inited = True + return [] + res: list[RawPost] = [] + for raw_post in raw_post_list: + post_id = self.get_id(raw_post) + if post_id in self.exists_posts: + continue + if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60: + continue + try: + if not self.filter_platform_custom(raw_post): + continue + except NotImplementedError: + pass + try: + self.get_category(raw_post) + except CategoryNotSupport: + continue + res.append(raw_post) + self.exists_posts.add(post_id) + return res + + async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]: + res: list[RawPost] = [] + for raw_post in raw_post_list: + if self.categories: + cat = self.get_category(raw_post) + if cats and cat not in cats: + continue + if self.enable_tag and tags: + flag = False + post_tags = self.get_tags(raw_post) + for tag in post_tags: + if tag in tags: + flag = True + break + if not flag: + continue + res.append(raw_post) + return res + + async def fetch_new_post(self, users: list[User]) -> list[tuple[User, list[Post]]]: + try: + config = Config() + post_list = await self.get_sub_list() + new_posts = await self.filter_common(post_list) + res: list[tuple[User, list[Post]]] = [] + if not new_posts: + return [] + else: + for post in new_posts: + logger.info('fetch new post from {}: {}'.format(self.platform_name, self.get_id(post))) + for user in users: + required_tags = config.get_sub_tags(self.platform_name, 'default', user.user_type, user.user) if self.enable_tag else [] + cats = config.get_sub_category(self.platform_name, 'default', user.user_type, user.user) + user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags) + user_post: list[Post] = [] + for raw_post in user_raw_post: + user_post.append(await self._parse_with_cache(raw_post)) + res.append((user, user_post)) + self.cache = {} + return res + except httpx.RequestError as err: + logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) + return [] diff --git a/src/plugins/hk_reporter/platform/rss.py b/src/plugins/hk_reporter/platform/rss.py index 97b347e..f5c1ac5 100644 --- a/src/plugins/hk_reporter/platform/rss.py +++ b/src/plugins/hk_reporter/platform/rss.py @@ -1,63 +1,41 @@ -from ..utils import Singleton -from ..post import Post -from collections import defaultdict +import calendar +from typing import Any, Optional + from bs4 import BeautifulSoup as bs -from nonebot import logger import feedparser import httpx -import time -import calendar -async def get_rss_raw_data(url) -> str: - async with httpx.AsyncClient() as client: - res = await client.get(url, timeout=10.0) - return res.text +from ..post import Post +from ..types import RawPost, Target +from .platform import Platform -async def get_rss_info(url) -> str: - data = await get_rss_raw_data(url) - feed = feedparser.parse(data) - return feed.feed.title +class Rss(Platform): -class Rss(metaclass=Singleton): + categories = {} + enable_tag = False + platform_name = 'rss' - def __init__(self): - self.exists_posts = defaultdict(set) - self.inited = defaultdict(lambda: False) + @staticmethod + async def get_account_name(target: Target) -> Optional[str]: + async with httpx.AsyncClient() as client: + res = await client.get(target, timeout=10.0) + feed = feedparser.parse(res.text) + return feed['feed']['title'] - def filter(self, data, target, init=False) -> list[Post]: - feed = feedparser.parse(data) - entries = feed.entries - res = [] - for entry in entries: - entry_id = entry.id - if init: - self.exists_posts[target].add(entry_id) - continue - if entry_id in self.exists_posts[target]: - continue - # if time.time() - calendar.timegm(entry.published_parsed) > 2 * 60 * 60: - # continue - res.append(self.parse(entry, target)) - return res + def get_date(self, post: RawPost) -> int: + return calendar.timegm(post.published_parsed) - def parse(self, entry, target) -> Post: - soup = bs(entry.description, 'html.parser') + def get_id(self, post: RawPost) -> Any: + return post.id + + async def get_sub_list(self, target: Target) -> list[RawPost]: + async with httpx.AsyncClient() as client: + res = await client.get(target, timeout=10.0) + feed = feedparser.parse(res) + return feed.entries + + async def parse(self, raw_post: RawPost) -> Post: + soup = bs(raw_post.description, 'html.parser') text = soup.text pics = list(map(lambda x: x.attrs['src'], soup('img'))) - self.exists_posts[target].add(entry.id) - return Post('rss', text, entry.link, pics) - - async def fetch_new_post(self, target) -> list[Post]: - try: - raw_data = await get_rss_raw_data(target) - if self.inited[target]: - return self.filter(raw_data, target) - else: - self.filter(raw_data, target, True) - logger.info('rss init {} success'.format(target)) - logger.info('post list: {}'.format(self.exists_posts[target])) - self.inited[target] = True - return [] - except httpx.RequestError as err: - logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) - return [] + return Post('rss', text, raw_post.link, pics) diff --git a/src/plugins/hk_reporter/platform/utils.py b/src/plugins/hk_reporter/platform/utils.py index 5b3b22c..53449d5 100644 --- a/src/plugins/hk_reporter/platform/utils.py +++ b/src/plugins/hk_reporter/platform/utils.py @@ -1,59 +1,38 @@ -import time -import asyncio import nonebot from nonebot import logger from collections import defaultdict -from .weibo import Weibo, get_user_info as weibo_user_info -from .bilibili import Bilibili, get_user_info as bilibili_user_info -from .rss import Rss, get_rss_info as rss_info -from .arkninghts import Arknights +from typing import Type +from .weibo import Weibo +from .bilibili import Bilibili +from .rss import Rss +from .platform import PlatformProto from ..config import Config from ..post import Post from ..send import send_msgs async def check_sub_target(target_type, target): - if target_type == 'weibo': - return await weibo_user_info(target) - elif target_type == 'bilibili': - return await bilibili_user_info(target) - elif target_type == 'rss': - return await rss_info(target) - elif target_type == 'arknights': - return '明日方舟游戏公告' - else: - return None + return await platform_manager[target_type].get_account_name(target) - -scheduler_last_run = defaultdict(lambda: 0) -async def scheduler(fun, target_type): - platform_interval = { - 'weibo': 3 - } - if (wait_time := time.time() - scheduler_last_run[target_type]) < platform_interval[target_type]: - await asyncio.sleep(wait_time) - await fun() +platform_manager: dict[str, PlatformProto] = { + 'bilibili': Bilibili(), + 'weibo': Weibo(), + 'rss': Rss() + } async def fetch_and_send(target_type: str): config = Config() - platform_manager = { - 'bilibili': Bilibili(), - 'weibo': Weibo(), - 'rss': Rss(), - 'arknights': Arknights() - } target = config.get_next_target(target_type) if not target: return logger.debug('try to fecth new posts from {}, target: {}'.format(target_type, target)) - new_posts: list[Post] = await platform_manager[target_type].fetch_new_post(target) send_list = config.target_user_cache[target_type][target] bot_list = list(nonebot.get_bots().values()) bot = bot_list[0] if bot_list else None - for new_post in new_posts: - logger.warning('get new {} dynamic: {}'.format(target_type, new_post.url)) - logger.warning(new_post) - if not bot: - logger.warning('no bot connected') - else: - for to_send in send_list: - send_msgs(bot, to_send['user'], to_send['user_type'], await new_post.generate_messages()) + to_send = await platform_manager[target_type].fetch_new_post(target, send_list) + for user, send_list in to_send: + for send_post in send_list: + logger.debug('send to {}: {}'.format(user, send_post)) + if not bot: + logger.warning('no bot connected') + else: + send_msgs(bot, user.user, user.user_type, await send_post.generate_messages()) diff --git a/src/plugins/hk_reporter/platform/weibo.py b/src/plugins/hk_reporter/platform/weibo.py index 762a3e2..676dd72 100644 --- a/src/plugins/hk_reporter/platform/weibo.py +++ b/src/plugins/hk_reporter/platform/weibo.py @@ -1,74 +1,74 @@ -import httpx +from collections import defaultdict +from datetime import datetime import json import time -from collections import defaultdict +from typing import Any, Optional + from bs4 import BeautifulSoup as bs -from datetime import datetime +import httpx from nonebot import logger -from ..utils import Singleton from ..post import Post +from ..types import * +from ..utils import Singleton +from .platform import Platform -class Weibo(metaclass=Singleton): +class Weibo(Platform): - def __init__(self): - self.exists_posts = defaultdict(set) - self.inited = defaultdict(lambda: False) + categories = { + 1: '转发', + 2: '视频', + 3: '图文' + } + enable_tag = False + platform_name = 'weibo' - async def get_user_post_list(self, weibo_id: str): + @staticmethod + async def get_account_name(target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: - params = { 'containerid': '107603' + weibo_id} + param = {'containerid': '100505' + target} + res = await client.get('https://m.weibo.cn/api/container/getIndex', params=param) + res_dict = json.loads(res.text) + if res_dict.get('ok') == 1: + return res_dict['data']['userInfo']['screen_name'] + else: + return None + + async def get_sub_list(self, target: Target) -> list[RawPost]: + async with httpx.AsyncClient() as client: + params = { 'containerid': '107603' + target} res = await client.get('https://m.weibo.cn/api/container/getIndex?', params=params, timeout=4.0) - return res.text + res_data = json.loads(res.text) + if not res_data['ok']: + return [] + return res_data['data']['cards'] - def filter_weibo(self, weibo_raw_text, target, init=False): - weibo_dict = json.loads(weibo_raw_text) - weibos = weibo_dict['data']['cards'] - res: list[Post] = [] - for weibo in weibos: - if weibo['card_type'] != 9: - continue - info = weibo['mblog'] - if init: - self.exists_posts[target].add(info['id']) - continue - if info['id'] in self.exists_posts[target]: - continue - created_time = datetime.strptime(info['created_at'], '%a %b %d %H:%M:%S %z %Y') - if time.time() - created_time.timestamp() > 60 * 60 * 2: - continue - res.append(self.parse_weibo(weibo, target)) - return res + def get_id(self, post: RawPost) -> Any: + return post['mblog']['id'] - def parse_weibo(self, weibo_dict, target): - info = weibo_dict['mblog'] + def filter_platform_custom(self, raw_post: RawPost) -> bool: + return raw_post['card_type'] == 9 + + def get_date(self, raw_post: RawPost) -> float: + created_time = datetime.strptime(raw_post['mblog']['created_at'], '%a %b %d %H:%M:%S %z %Y') + return created_time.timestamp() + + def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]: + "Return Tag list of given RawPost" + return None + + def get_category(self, raw_post: RawPost) -> Category: + if raw_post['mblog'].get('retweeted_status'): + return Category(1) + elif raw_post['mblog'].get('page_info') and raw_post['mblog']['page_info'].get('type') == 'video': + return Category(2) + else: + return Category(3) + + async def parse(self, raw_post: RawPost) -> Post: + info = raw_post['mblog'] parsed_text = bs(info['text'], 'html.parser').text pic_urls = [img['large']['url'] for img in info.get('pics', [])] - self.exists_posts[target].add(info['id']) detail_url = 'https://weibo.com/{}/{}'.format(info['user']['id'], info['bid']) # return parsed_text, detail_url, pic_urls return Post('weibo', parsed_text, detail_url, pic_urls) - - async def fetch_new_post(self, target): - try: - post_list = await self.get_user_post_list(target) - if not self.inited[target]: - self.filter_weibo(post_list, target, True) - logger.info('weibo init {} success'.format(target)) - logger.info('post list: {}'.format(self.exists_posts[target])) - self.inited[target] = True - return [] - return self.filter_weibo(post_list, target) - except httpx.RequestError as err: - logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) - return [] - -async def get_user_info(id): - async with httpx.AsyncClient() as client: - param = {'containerid': '100505' + id} - res = await client.get('https://m.weibo.cn/api/container/getIndex', params=param) - res_dict = json.loads(res.text) - if res_dict.get('ok') == 1: - return res_dict['data']['userInfo']['screen_name'] - else: - return None diff --git a/src/plugins/hk_reporter/plugin_config.py b/src/plugins/hk_reporter/plugin_config.py index 50c64bb..83554ff 100644 --- a/src/plugins/hk_reporter/plugin_config.py +++ b/src/plugins/hk_reporter/plugin_config.py @@ -1,4 +1,5 @@ from pydantic import BaseSettings +import nonebot class PlugConfig(BaseSettings): @@ -8,3 +9,6 @@ class PlugConfig(BaseSettings): class Config: extra = 'ignore' + +global_config = nonebot.get_driver().config +plugin_config = PlugConfig(**global_config.dict()) diff --git a/src/plugins/hk_reporter/post.py b/src/plugins/hk_reporter/post.py index ce22381..1003a44 100644 --- a/src/plugins/hk_reporter/post.py +++ b/src/plugins/hk_reporter/post.py @@ -1,12 +1,14 @@ -from . import plugin_config +from dataclasses import dataclass, field +from .plugin_config import plugin_config from .utils import parse_text +@dataclass class Post: target_type: str text: str url: str - pics: list[str] + pics: list[str] = field(default_factory=list) async def generate_messages(self): if plugin_config.hk_reporter_use_pic: @@ -24,11 +26,5 @@ class Post: res.append("[CQ:image,file={url}]".format(url=pic)) return res - def __init__(self, target_type, text, url, pics=[]): - self.target_type = target_type - self.text = text - self.url = url - self.pics = pics - def __str__(self): return 'type: {}\ntext: {}\nurl: {}\npic: {}'.format(self.target_type, self.text[:50], self.url, ','.join(map(lambda x: 'b64img' if x.startswith('base64') else x, self.pics))) diff --git a/src/plugins/hk_reporter/types.py b/src/plugins/hk_reporter/types.py new file mode 100644 index 0000000..8f69e6f --- /dev/null +++ b/src/plugins/hk_reporter/types.py @@ -0,0 +1,12 @@ +from typing import Any, NewType +from dataclasses import dataclass + +RawPost = NewType('RawPost', Any) +Target = NewType('Target', str) +Category = NewType('Category', int) +Tag = NewType('Tag', str) + +@dataclass +class User: + user: str + user_type: str diff --git a/src/plugins/hk_reporter/utils.py b/src/plugins/hk_reporter/utils.py index 41ab185..c351c69 100644 --- a/src/plugins/hk_reporter/utils.py +++ b/src/plugins/hk_reporter/utils.py @@ -1,5 +1,6 @@ import os import asyncio +from typing import Optional import nonebot from nonebot import logger import base64 @@ -7,7 +8,7 @@ from pyppeteer import launch from html import escape from hashlib import sha256 -from . import plugin_config +from .plugin_config import plugin_config class Singleton(type): _instances = {} @@ -23,7 +24,7 @@ class Render(metaclass=Singleton): def __init__(self): self.lock = asyncio.Lock() - async def render(self, url: str, viewport: dict = None, target: str = None) -> str: + async def render(self, url: str, viewport: Optional[dict] = None, target: Optional[str] = None) -> str: async with self.lock: if plugin_config.hk_reporter_use_local: browser = await launch(executablePath='/usr/bin/chromium', args=['--no-sandbox']) @@ -40,7 +41,7 @@ class Render(metaclass=Singleton): data = await page.screenshot(type='jpeg', encoding='base64') await page.close() await browser.close() - return data + return str(data) async def text_to_pic(self, text: str) -> str: hash_text = sha256(text.encode()).hexdigest()[:20]