From 709a3e214b92168a122edd1737c498828ab7979d Mon Sep 17 00:00:00 2001 From: felinae98 <731499577@qq.com> Date: Sat, 12 Feb 2022 10:20:02 +0800 Subject: [PATCH] format code --- .pre-commit-config.yaml | 1 + bot.py | 6 +- src/plugins/auto_agree.py | 24 +- src/plugins/nonebot_bison/__init__.py | 27 +- .../nonebot_bison/admin_page/__init__.py | 95 ++-- src/plugins/nonebot_bison/admin_page/api.py | 155 ++++--- src/plugins/nonebot_bison/admin_page/jwt.py | 19 +- .../nonebot_bison/admin_page/token_manager.py | 16 +- src/plugins/nonebot_bison/config.py | 157 +++++-- src/plugins/nonebot_bison/config_manager.py | 226 +++++---- .../nonebot_bison/platform/__init__.py | 16 +- .../nonebot_bison/platform/arknights.py | 144 +++--- .../nonebot_bison/platform/bilibili.py | 109 +++-- .../nonebot_bison/platform/ncm_artist.py | 44 +- .../nonebot_bison/platform/ncm_radio.py | 46 +- .../nonebot_bison/platform/platform.py | 154 ++++-- src/plugins/nonebot_bison/platform/rss.py | 27 +- src/plugins/nonebot_bison/platform/wechat.py | 6 +- src/plugins/nonebot_bison/platform/weibo.py | 153 +++--- src/plugins/nonebot_bison/plugin_config.py | 14 +- src/plugins/nonebot_bison/post.py | 69 +-- src/plugins/nonebot_bison/scheduler.py | 71 ++- src/plugins/nonebot_bison/send.py | 19 +- src/plugins/nonebot_bison/types.py | 12 +- tests/conftest.py | 29 +- tests/platforms/test_arknights.py | 93 ++-- tests/platforms/test_bilibili.py | 46 +- tests/platforms/test_ncm_artist.py | 53 ++- tests/platforms/test_ncm_radio.py | 58 +-- tests/platforms/test_platform.py | 437 ++++++++++-------- tests/platforms/test_weibo.py | 133 +++--- tests/platforms/utils.py | 16 +- tests/test_config_manager.py | 59 ++- tests/test_merge_pic.py | 96 ++-- tests/test_render.py | 15 +- 35 files changed, 1613 insertions(+), 1032 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b577ac8..77e9bc6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,6 +9,7 @@ repos: rev: 5.10.1 hooks: - id: isort + args: ["--profile", "black", "--filter-files"] - repo: https://github.com/psf/black rev: 22.1.0 diff --git a/bot.py b/bot.py index 6819d73..957a2e5 100644 --- a/bot.py +++ b/bot.py @@ -5,11 +5,11 @@ nonebot.init(command_start=[""]) app = nonebot.get_asgi() driver = nonebot.get_driver() -driver.register_adapter('cqhttp', CQHTTPBot) +driver.register_adapter("cqhttp", CQHTTPBot) nonebot.load_builtin_plugins() -nonebot.load_plugin('nonebot_plugin_help') -nonebot.load_plugins('src/plugins') +nonebot.load_plugin("nonebot_plugin_help") +nonebot.load_plugins("src/plugins") if __name__ == "__main__": nonebot.run(app="bot:app") diff --git a/src/plugins/auto_agree.py b/src/plugins/auto_agree.py index d4b4998..463e54a 100644 --- a/src/plugins/auto_agree.py +++ b/src/plugins/auto_agree.py @@ -1,18 +1,26 @@ -from nonebot import on_request, logger +from nonebot import logger, on_request from nonebot.adapters.cqhttp import Bot, Event +from nonebot.adapters.cqhttp.event import ( + FriendRequestEvent, + GroupRequestEvent, + RequestEvent, +) +from nonebot.adapters.cqhttp.permission import PRIVATE_FRIEND from nonebot.permission import SUPERUSER from nonebot.typing import T_State -from nonebot.adapters.cqhttp.permission import PRIVATE_FRIEND -from nonebot.adapters.cqhttp.event import FriendRequestEvent, GroupRequestEvent, RequestEvent friend_req = on_request(priority=5) + @friend_req.handle() async def add_superuser(bot: Bot, event: RequestEvent, state: T_State): - if str(event.user_id) in bot.config.superusers and event.request_type == 'private': + if str(event.user_id) in bot.config.superusers and event.request_type == "private": await event.approve(bot) - logger.info('add user {}'.format(event.user_id)) - elif event.sub_type == 'invite' and str(event.user_id) in bot.config.superusers and event.request_type == 'group': + logger.info("add user {}".format(event.user_id)) + elif ( + event.sub_type == "invite" + and str(event.user_id) in bot.config.superusers + and event.request_type == "group" + ): await event.approve(bot) - logger.info('add group {}'.format(event.group_id)) - + logger.info("add group {}".format(event.group_id)) diff --git a/src/plugins/nonebot_bison/__init__.py b/src/plugins/nonebot_bison/__init__.py index 73abbeb..353a8a3 100644 --- a/src/plugins/nonebot_bison/__init__.py +++ b/src/plugins/nonebot_bison/__init__.py @@ -1,16 +1,17 @@ import nonebot -from . import config_manager -from . import config -from . import scheduler -from . import send -from . import post -from . import platform -from . import types -from . import utils -from . import admin_page +from . import ( + admin_page, + config, + config_manager, + platform, + post, + scheduler, + send, + types, + utils, +) -__help__version__ = '0.4.3' -__help__plugin__name__ = 'nonebot_bison' -__usage__ = ('本bot可以提供b站、微博等社交媒体的消息订阅,详情' - '请查看本bot文档,或者at本bot发送“添加订阅”订阅第一个帐号') +__help__version__ = "0.4.3" +__help__plugin__name__ = "nonebot_bison" +__usage__ = "本bot可以提供b站、微博等社交媒体的消息订阅,详情" "请查看本bot文档,或者at本bot发送“添加订阅”订阅第一个帐号" diff --git a/src/plugins/nonebot_bison/admin_page/__init__.py b/src/plugins/nonebot_bison/admin_page/__init__.py index 745e0ce..0aba931 100644 --- a/src/plugins/nonebot_bison/admin_page/__init__.py +++ b/src/plugins/nonebot_bison/admin_page/__init__.py @@ -1,8 +1,9 @@ -from dataclasses import dataclass import os +from dataclasses import dataclass from pathlib import Path from typing import Union +import socketio from fastapi.staticfiles import StaticFiles from nonebot import get_driver, on_command from nonebot.adapters.cqhttp.bot import Bot @@ -11,7 +12,6 @@ from nonebot.drivers.fastapi import Driver from nonebot.log import logger from nonebot.rule import to_me from nonebot.typing import T_State -import socketio from ..plugin_config import plugin_config from .api import ( @@ -27,21 +27,21 @@ from .api import ( from .jwt import load_jwt from .token_manager import token_manager as tm -URL_BASE = '/bison/' -GLOBAL_CONF_URL = f'{URL_BASE}api/global_conf' -AUTH_URL = f'{URL_BASE}api/auth' -SUBSCRIBE_URL = f'{URL_BASE}api/subs' -GET_TARGET_NAME_URL = f'{URL_BASE}api/target_name' -TEST_URL = f'{URL_BASE}test' +URL_BASE = "/bison/" +GLOBAL_CONF_URL = f"{URL_BASE}api/global_conf" +AUTH_URL = f"{URL_BASE}api/auth" +SUBSCRIBE_URL = f"{URL_BASE}api/subs" +GET_TARGET_NAME_URL = f"{URL_BASE}api/target_name" +TEST_URL = f"{URL_BASE}test" STATIC_PATH = (Path(__file__).parent / "dist").resolve() sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") socket_app = socketio.ASGIApp(sio, socketio_path="socket") -class SinglePageApplication(StaticFiles): - def __init__(self, directory: os.PathLike, index='index.html'): +class SinglePageApplication(StaticFiles): + def __init__(self, directory: os.PathLike, index="index.html"): self.index = index super().__init__(directory=directory, packages=None, html=True, check_dir=True) @@ -51,12 +51,13 @@ class SinglePageApplication(StaticFiles): return await super().lookup_path(self.index) return (full_path, stat_res) -def register_router_fastapi(driver: Driver, socketio): - from fastapi.security import OAuth2PasswordBearer - from fastapi.param_functions import Depends - from fastapi import HTTPException, status - oath_scheme = OAuth2PasswordBearer(tokenUrl='token') +def register_router_fastapi(driver: Driver, socketio): + from fastapi import HTTPException, status + from fastapi.param_functions import Depends + from fastapi.security import OAuth2PasswordBearer + + oath_scheme = OAuth2PasswordBearer(tokenUrl="token") async def get_jwt_obj(token: str = Depends(oath_scheme)): obj = load_jwt(token) @@ -64,10 +65,12 @@ def register_router_fastapi(driver: Driver, socketio): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) return obj - async def check_group_permission(groupNumber: str, token_obj: dict = Depends(get_jwt_obj)): - groups = token_obj['groups'] + async def check_group_permission( + groupNumber: str, token_obj: dict = Depends(get_jwt_obj) + ): + groups = token_obj["groups"] for group in groups: - if int(groupNumber) == group['id']: + if int(groupNumber) == group["id"]: return raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) @@ -88,17 +91,35 @@ def register_router_fastapi(driver: Driver, socketio): @app.get(SUBSCRIBE_URL) async def subs(jwt_obj: dict = Depends(get_jwt_obj)): return await get_subs_info(jwt_obj) + @app.get(GET_TARGET_NAME_URL) - async def _get_target_name(platformName: str, target: str, jwt_obj: dict = Depends(get_jwt_obj)): + async def _get_target_name( + platformName: str, target: str, jwt_obj: dict = Depends(get_jwt_obj) + ): return await get_target_name(platformName, target, jwt_obj) + @app.post(SUBSCRIBE_URL, dependencies=[Depends(check_group_permission)]) async def _add_group_subs(groupNumber: str, req: AddSubscribeReq): - return await add_group_sub(group_number=groupNumber, platform_name=req.platformName, - target=req.target, target_name=req.targetName, cats=req.cats, tags=req.tags) + return await add_group_sub( + group_number=groupNumber, + platform_name=req.platformName, + target=req.target, + target_name=req.targetName, + cats=req.cats, + tags=req.tags, + ) + @app.patch(SUBSCRIBE_URL, dependencies=[Depends(check_group_permission)]) async def _update_group_subs(groupNumber: str, req: AddSubscribeReq): - return await update_group_sub(group_number=groupNumber, platform_name=req.platformName, - target=req.target, target_name=req.targetName, cats=req.cats, tags=req.tags) + return await update_group_sub( + group_number=groupNumber, + platform_name=req.platformName, + target=req.target, + target_name=req.targetName, + cats=req.cats, + tags=req.tags, + ) + @app.delete(SUBSCRIBE_URL, dependencies=[Depends(check_group_permission)]) async def _del_group_subs(groupNumber: str, target: str, platformName: str): return await del_group_sub(groupNumber, platformName, target) @@ -108,8 +129,8 @@ def register_router_fastapi(driver: Driver, socketio): def init(): driver = get_driver() - if driver.type == 'fastapi': - assert(isinstance(driver, Driver)) + if driver.type == "fastapi": + assert isinstance(driver, Driver) register_router_fastapi(driver, socket_app) else: logger.warning(f"Driver {driver.type} not supported") @@ -118,19 +139,25 @@ def init(): port = driver.config.port if host in ["0.0.0.0", "127.0.0.1"]: host = "localhost" - logger.opt(colors=True).info(f"Nonebot test frontend will be running at: " - f"http://{host}:{port}{URL_BASE}") + logger.opt(colors=True).info( + f"Nonebot test frontend will be running at: " + f"http://{host}:{port}{URL_BASE}" + ) -if (STATIC_PATH / 'index.html').exists(): + +if (STATIC_PATH / "index.html").exists(): init() - get_token = on_command('后台管理', rule=to_me(), priority=5) + get_token = on_command("后台管理", rule=to_me(), priority=5) + @get_token.handle() async def send_token(bot: "Bot", event: PrivateMessageEvent, state: T_State): token = tm.get_user_token((event.get_user_id(), event.sender.nickname)) - await get_token.finish(f'请访问: {plugin_config.bison_outer_url}auth/{token}') - get_token.__help__name__ = '获取后台管理地址' - get_token.__help__info__ = ('获取管理bot后台的地址,该地址会' - '在一段时间过后过期,请不要泄漏该地址') + await get_token.finish(f"请访问: {plugin_config.bison_outer_url}auth/{token}") + + get_token.__help__name__ = "获取后台管理地址" + get_token.__help__info__ = "获取管理bot后台的地址,该地址会" "在一段时间过后过期,请不要泄漏该地址" else: - logger.warning("Frontend file not found, please compile it or use docker or pypi version") + logger.warning( + "Frontend file not found, please compile it or use docker or pypi version" + ) diff --git a/src/plugins/nonebot_bison/admin_page/api.py b/src/plugins/nonebot_bison/admin_page/api.py index ff17d62..b2f519f 100644 --- a/src/plugins/nonebot_bison/admin_page/api.py +++ b/src/plugins/nonebot_bison/admin_page/api.py @@ -6,110 +6,141 @@ from ..platform import check_sub_target, platform_manager from .jwt import pack_jwt from .token_manager import token_manager + async def test(): return {"status": 200, "text": "test"} + async def get_global_conf(): res = {} for platform_name, platform in platform_manager.items(): res[platform_name] = { - 'platformName': platform_name, - 'categories': platform.categories, - 'enabledTag': platform.enable_tag, - 'name': platform.name, - 'hasTarget': getattr(platform, 'has_target') - } - return { 'platformConf': res } + "platformName": platform_name, + "categories": platform.categories, + "enabledTag": platform.enable_tag, + "name": platform.name, + "hasTarget": getattr(platform, "has_target"), + } + return {"platformConf": res} + async def get_admin_groups(qq: int): bot = nonebot.get_bot() - groups = await bot.call_api('get_group_list') + groups = await bot.call_api("get_group_list") res = [] for group in groups: - group_id = group['group_id'] - users = await bot.call_api('get_group_member_list', group_id=group_id) + group_id = group["group_id"] + users = await bot.call_api("get_group_member_list", group_id=group_id) for user in users: - if user['user_id'] == qq and user['role'] in ('owner', 'admin'): - res.append({'id': group_id, 'name': group['group_name']}) + if user["user_id"] == qq and user["role"] in ("owner", "admin"): + res.append({"id": group_id, "name": group["group_name"]}) return res + async def auth(token: str): if qq_tuple := token_manager.get_user(token): qq, nickname = qq_tuple bot = nonebot.get_bot() - assert(isinstance(bot, Bot)) - groups = await bot.call_api('get_group_list') + assert isinstance(bot, Bot) + groups = await bot.call_api("get_group_list") if str(qq) in nonebot.get_driver().config.superusers: jwt_obj = { - 'id': str(qq), - 'groups': list(map( - lambda info: {'id': info['group_id'], 'name': info['group_name']}, - groups)), - } + "id": str(qq), + "groups": list( + map( + lambda info: { + "id": info["group_id"], + "name": info["group_name"], + }, + groups, + ) + ), + } ret_obj = { - 'type': 'admin', - 'name': nickname, - 'id': str(qq), - 'token': pack_jwt(jwt_obj) - } - return { 'status': 200, **ret_obj } + "type": "admin", + "name": nickname, + "id": str(qq), + "token": pack_jwt(jwt_obj), + } + return {"status": 200, **ret_obj} if admin_groups := await get_admin_groups(int(qq)): - jwt_obj = { - 'id': str(qq), - 'groups': admin_groups - } + jwt_obj = {"id": str(qq), "groups": admin_groups} ret_obj = { - 'type': 'user', - 'name': nickname, - 'id': str(qq), - 'token': pack_jwt(jwt_obj) - } - return { 'status': 200, **ret_obj } + "type": "user", + "name": nickname, + "id": str(qq), + "token": pack_jwt(jwt_obj), + } + return {"status": 200, **ret_obj} else: - return { 'status': 400, 'type': '', 'name': '', 'id': '', 'token': '' } + return {"status": 400, "type": "", "name": "", "id": "", "token": ""} else: - return { 'status': 400, 'type': '', 'name': '', 'id': '', 'token': '' } + return {"status": 400, "type": "", "name": "", "id": "", "token": ""} + async def get_subs_info(jwt_obj: dict): - groups = jwt_obj['groups'] + groups = jwt_obj["groups"] res = {} for group in groups: - group_id = group['id'] + group_id = group["id"] config = Config() - subs = list(map(lambda sub: { - 'platformName': sub['target_type'], 'target': sub['target'], 'targetName': sub['target_name'], 'cats': sub['cats'], 'tags': sub['tags'] - }, config.list_subscribe(group_id, 'group'))) - res[group_id] = { - 'name': group['name'], - 'subscribes': subs - } + subs = list( + map( + lambda sub: { + "platformName": sub["target_type"], + "target": sub["target"], + "targetName": sub["target_name"], + "cats": sub["cats"], + "tags": sub["tags"], + }, + config.list_subscribe(group_id, "group"), + ) + ) + res[group_id] = {"name": group["name"], "subscribes": subs} return res -async def get_target_name(platform_name: str, target: str, jwt_obj: dict): - return {'targetName': await check_sub_target(platform_name, target)} -async def add_group_sub(group_number: str, platform_name: str, target: str, - target_name: str, cats: list[int], tags: list[str]): +async def get_target_name(platform_name: str, target: str, jwt_obj: dict): + return {"targetName": await check_sub_target(platform_name, target)} + + +async def add_group_sub( + group_number: str, + platform_name: str, + target: str, + target_name: str, + cats: list[int], + tags: list[str], +): config = Config() - config.add_subscribe(int(group_number), 'group', target, target_name, platform_name, cats, tags) - return { 'status': 200, 'msg': '' } + config.add_subscribe( + int(group_number), "group", target, target_name, platform_name, cats, tags + ) + return {"status": 200, "msg": ""} + async def del_group_sub(group_number: str, platform_name: str, target: str): config = Config() try: - config.del_subscribe(int(group_number), 'group', target, platform_name) + config.del_subscribe(int(group_number), "group", target, platform_name) except (NoSuchUserException, NoSuchSubscribeException): - return { 'status': 400, 'msg': '删除错误' } - return { 'status': 200, 'msg': '' } + return {"status": 400, "msg": "删除错误"} + return {"status": 200, "msg": ""} -async def update_group_sub(group_number: str, platform_name: str, target: str, - target_name: str, cats: list[int], tags: list[str]): +async def update_group_sub( + group_number: str, + platform_name: str, + target: str, + target_name: str, + cats: list[int], + tags: list[str], +): config = Config() try: - config.update_subscribe(int(group_number), 'group', - target, target_name, platform_name, cats, tags) + config.update_subscribe( + int(group_number), "group", target, target_name, platform_name, cats, tags + ) except (NoSuchUserException, NoSuchSubscribeException): - return { 'status': 400, 'msg': '更新错误' } - return { 'status': 200, 'msg': '' } - + return {"status": 400, "msg": "更新错误"} + return {"status": 200, "msg": ""} diff --git a/src/plugins/nonebot_bison/admin_page/jwt.py b/src/plugins/nonebot_bison/admin_page/jwt.py index c607747..661621a 100644 --- a/src/plugins/nonebot_bison/admin_page/jwt.py +++ b/src/plugins/nonebot_bison/admin_page/jwt.py @@ -1,20 +1,23 @@ +import datetime import random import string from typing import Optional -import jwt -import datetime -_key = ''.join(random.SystemRandom().choice(string.ascii_letters) for _ in range(16)) +import jwt + +_key = "".join(random.SystemRandom().choice(string.ascii_letters) for _ in range(16)) + def pack_jwt(obj: dict) -> str: return jwt.encode( - {'exp': datetime.datetime.utcnow() + datetime.timedelta(hours=1), **obj}, - _key, algorithm='HS256' - ) + {"exp": datetime.datetime.utcnow() + datetime.timedelta(hours=1), **obj}, + _key, + algorithm="HS256", + ) + def load_jwt(token: str) -> Optional[dict]: try: - return jwt.decode(token, _key, algorithms=['HS256']) + return jwt.decode(token, _key, algorithms=["HS256"]) except: return None - diff --git a/src/plugins/nonebot_bison/admin_page/token_manager.py b/src/plugins/nonebot_bison/admin_page/token_manager.py index 47e3927..e540656 100644 --- a/src/plugins/nonebot_bison/admin_page/token_manager.py +++ b/src/plugins/nonebot_bison/admin_page/token_manager.py @@ -1,24 +1,26 @@ -from typing import Optional -from expiringdict import ExpiringDict import random -import string +import string +from typing import Optional + +from expiringdict import ExpiringDict + class TokenManager: - def __init__(self): - self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60*10) + self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60 * 10) def get_user(self, token: str) -> Optional[tuple]: res = self.token_manager.get(token) - assert(res is None or isinstance(res, tuple)) + assert res is None or isinstance(res, tuple) return res def save_user(self, token: str, qq: tuple) -> None: self.token_manager[token] = qq def get_user_token(self, qq: tuple) -> str: - token = ''.join(random.choices(string.ascii_letters + string.digits, k=16)) + token = "".join(random.choices(string.ascii_letters + string.digits, k=16)) self.save_user(token, qq) return token + token_manager = TokenManager() diff --git a/src/plugins/nonebot_bison/config.py b/src/plugins/nonebot_bison/config.py index 0712fcb..a021aeb 100644 --- a/src/plugins/nonebot_bison/config.py +++ b/src/plugins/nonebot_bison/config.py @@ -1,6 +1,6 @@ +import os from collections import defaultdict from os import path -import os from typing import DefaultDict, Literal, Mapping, TypedDict import nonebot @@ -14,26 +14,30 @@ from .utils import Singleton supported_target_type = platform_manager.keys() + def get_config_path() -> str: if plugin_config.bison_config_path: data_dir = plugin_config.bison_config_path else: working_dir = os.getcwd() - data_dir = path.join(working_dir, 'data') + data_dir = path.join(working_dir, "data") if not path.isdir(data_dir): os.makedirs(data_dir) - old_path = path.join(data_dir, 'hk_reporter.json') - new_path = path.join(data_dir, 'bison.json') + old_path = path.join(data_dir, "hk_reporter.json") + new_path = path.join(data_dir, "bison.json") if os.path.exists(old_path) and not os.path.exists(new_path): os.rename(old_path, new_path) return new_path + class NoSuchUserException(Exception): pass + class NoSuchSubscribeException(Exception): pass + class SubscribeContent(TypedDict): target: str target_type: str @@ -41,75 +45,105 @@ class SubscribeContent(TypedDict): cats: list[int] tags: list[str] + class ConfigContent(TypedDict): user: str user_type: Literal["group", "private"] subs: list[SubscribeContent] + class Config(metaclass=Singleton): migrate_version = 2 - + def __init__(self): - self.db = TinyDB(get_config_path(), encoding='utf-8') - self.kv_config = self.db.table('kv') - self.user_target = self.db.table('user_target') + self.db = TinyDB(get_config_path(), encoding="utf-8") + self.kv_config = self.db.table("kv") + self.user_target = self.db.table("user_target") self.target_user_cache: dict[str, defaultdict[Target, list[User]]] = {} self.target_user_cat_cache = {} self.target_user_tag_cache = {} self.target_list = {} self.next_index: DefaultDict[str, int] = defaultdict(lambda: 0) - - def add_subscribe(self, user, user_type, target, target_name, target_type, cats, tags): + + def add_subscribe( + self, user, user_type, target, target_name, target_type, cats, tags + ): user_query = Query() query = (user_query.user == user) & (user_query.user_type == user_type) - if (user_data := self.user_target.get(query)): + if user_data := self.user_target.get(query): # update - subs: list = user_data.get('subs', []) - subs.append({"target": target, "target_type": target_type, 'target_name': target_name, 'cats': cats, 'tags': tags}) + subs: list = user_data.get("subs", []) + subs.append( + { + "target": target, + "target_type": target_type, + "target_name": target_name, + "cats": cats, + "tags": tags, + } + ) self.user_target.update({"subs": subs}, query) else: # insert - self.user_target.insert({ - 'user': user, 'user_type': user_type, - 'subs': [{'target': target, 'target_type': target_type, 'target_name': target_name, 'cats': cats, 'tags': tags }] - }) + self.user_target.insert( + { + "user": user, + "user_type": user_type, + "subs": [ + { + "target": target, + "target_type": target_type, + "target_name": target_name, + "cats": cats, + "tags": tags, + } + ], + } + ) self.update_send_cache() def list_subscribe(self, user, user_type) -> list[SubscribeContent]: query = Query() - if user_sub := self.user_target.get((query.user == user) & (query.user_type ==user_type)): - return user_sub['subs'] + if user_sub := self.user_target.get( + (query.user == user) & (query.user_type == user_type) + ): + return user_sub["subs"] return [] def get_all_subscribe(self): return self.user_target - + def del_subscribe(self, user, user_type, target, target_type): user_query = Query() query = (user_query.user == user) & (user_query.user_type == user_type) if not (query_res := self.user_target.get(query)): raise NoSuchUserException() - subs = query_res.get('subs', []) + subs = query_res.get("subs", []) for idx, sub in enumerate(subs): - if sub.get('target') == target and sub.get('target_type') == target_type: + if sub.get("target") == target and sub.get("target_type") == target_type: subs.pop(idx) - self.user_target.update({'subs': subs}, query) + self.user_target.update({"subs": subs}, query) self.update_send_cache() return raise NoSuchSubscribeException() - def update_subscribe(self, user, user_type, target, target_name, target_type, cats, tags): + def update_subscribe( + self, user, user_type, target, target_name, target_type, cats, tags + ): user_query = Query() query = (user_query.user == user) & (user_query.user_type == user_type) - if (user_data := self.user_target.get(query)): + if user_data := self.user_target.get(query): # update - subs: list = user_data.get('subs', []) + subs: list = user_data.get("subs", []) find_flag = False for item in subs: - if item['target'] == target and item['target_type'] == target_type: - item['target_name'], item['cats'], item['tags'] = \ - target_name, cats, tags + if item["target"] == target and item["target_type"] == target_type: + item["target_name"], item["cats"], item["tags"] = ( + target_name, + cats, + tags, + ) find_flag = True break if not find_flag: @@ -121,33 +155,58 @@ class Config(metaclass=Singleton): def update_send_cache(self): res = {target_type: defaultdict(list) for target_type in supported_target_type} - cat_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} - tag_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} + cat_res = { + target_type: defaultdict(lambda: defaultdict(list)) + for target_type in supported_target_type + } + tag_res = { + target_type: defaultdict(lambda: defaultdict(list)) + for target_type in supported_target_type + } # res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} to_del = [] for user in self.user_target.all(): - for sub in user.get('subs', []): - if not sub.get('target_type') in supported_target_type: - to_del.append({'user': user['user'], 'user_type': user['user_type'], 'target': sub['target'], 'target_type': sub['target_type']}) + for sub in user.get("subs", []): + if not sub.get("target_type") in supported_target_type: + to_del.append( + { + "user": user["user"], + "user_type": user["user_type"], + "target": sub["target"], + "target_type": sub["target_type"], + } + ) continue - res[sub['target_type']][sub['target']].append(User(user['user'], user['user_type'])) - cat_res[sub['target_type']][sub['target']]['{}-{}'.format(user['user_type'], user['user'])] = sub['cats'] - tag_res[sub['target_type']][sub['target']]['{}-{}'.format(user['user_type'], user['user'])] = sub['tags'] + res[sub["target_type"]][sub["target"]].append( + User(user["user"], user["user_type"]) + ) + cat_res[sub["target_type"]][sub["target"]][ + "{}-{}".format(user["user_type"], user["user"]) + ] = sub["cats"] + tag_res[sub["target_type"]][sub["target"]][ + "{}-{}".format(user["user_type"], user["user"]) + ] = sub["tags"] self.target_user_cache = res self.target_user_cat_cache = cat_res self.target_user_tag_cache = tag_res for target_type in self.target_user_cache: - self.target_list[target_type] = list(self.target_user_cache[target_type].keys()) - - logger.info(f'Deleting {to_del}') + self.target_list[target_type] = list( + self.target_user_cache[target_type].keys() + ) + + logger.info(f"Deleting {to_del}") for d in to_del: self.del_subscribe(**d) def get_sub_category(self, target_type, target, user_type, user): - return self.target_user_cat_cache[target_type][target]['{}-{}'.format(user_type, user)] + return self.target_user_cat_cache[target_type][target][ + "{}-{}".format(user_type, user) + ] def get_sub_tags(self, target_type, target, user_type, user): - return self.target_user_tag_cache[target_type][target]['{}-{}'.format(user_type, user)] + return self.target_user_tag_cache[target_type][target][ + "{}-{}".format(user_type, user) + ] def get_next_target(self, target_type): # FIXME 插入或删除target后对队列的影响(但是并不是大问题 @@ -158,25 +217,27 @@ class Config(metaclass=Singleton): self.next_index[target_type] += 1 return res + def start_up(): config = Config() - if not (search_res := config.kv_config.search(Query().name=="version")): + if not (search_res := config.kv_config.search(Query().name == "version")): config.kv_config.insert({"name": "version", "value": config.migrate_version}) elif search_res[0].get("value") < config.migrate_version: query = Query() - version_query = (query.name == 'version') + version_query = query.name == "version" cur_version = search_res[0].get("value") if cur_version == 1: cur_version = 2 for user_conf in config.user_target.all(): conf_id = user_conf.doc_id - subs = user_conf['subs'] + subs = user_conf["subs"] for sub in subs: - sub['cats'] = [] - sub['tags'] = [] - config.user_target.update({'subs': subs}, doc_ids=[conf_id]) + sub["cats"] = [] + sub["tags"] = [] + config.user_target.update({"subs": subs}, doc_ids=[conf_id]) config.kv_config.update({"value": config.migrate_version}, version_query) # do migration config.update_send_cache() + nonebot.get_driver().on_startup(start_up) diff --git a/src/plugins/nonebot_bison/config_manager.py b/src/plugins/nonebot_bison/config_manager.py index 436c0c6..8e96ea6 100644 --- a/src/plugins/nonebot_bison/config_manager.py +++ b/src/plugins/nonebot_bison/config_manager.py @@ -7,7 +7,7 @@ from nonebot.adapters.cqhttp import Bot, Event, GroupMessageEvent from nonebot.adapters.cqhttp.message import Message from nonebot.adapters.cqhttp.permission import GROUP_ADMIN, GROUP_MEMBER, GROUP_OWNER from nonebot.matcher import Matcher -from nonebot.permission import Permission, SUPERUSER +from nonebot.permission import SUPERUSER, Permission from nonebot.rule import to_me from nonebot.typing import T_State @@ -16,137 +16,184 @@ from .platform import check_sub_target, platform_manager from .types import Target from .utils import parse_text + def _gen_prompt_template(prompt: str): - if hasattr(Message, 'template'): + if hasattr(Message, "template"): return Message.template(prompt) return prompt -common_platform = [p.platform_name for p in \ - filter(lambda platform: platform.enabled and platform.is_common, - platform_manager.values()) - ] -help_match = on_command('help', rule=to_me(), priority=5) +common_platform = [ + p.platform_name + for p in filter( + lambda platform: platform.enabled and platform.is_common, + platform_manager.values(), + ) +] + +help_match = on_command("help", rule=to_me(), priority=5) + + @help_match.handle() async def send_help(bot: Bot, event: Event, state: T_State): - message = '使用方法:\n@bot 添加订阅(仅管理员)\n@bot 查询订阅\n@bot 删除订阅(仅管理员)' + message = "使用方法:\n@bot 添加订阅(仅管理员)\n@bot 查询订阅\n@bot 删除订阅(仅管理员)" await help_match.finish(Message(await parse_text(message))) def do_add_sub(add_sub: Type[Matcher]): @add_sub.handle() async def init_promote(bot: Bot, event: Event, state: T_State): - state['_prompt'] = '请输入想要订阅的平台,目前支持,请输入冒号左边的名称:\n' + \ - ''.join(['{}:{}\n'.format(platform_name, platform_manager[platform_name].name) \ - for platform_name in common_platform]) + \ - '要查看全部平台请输入:“全部”' + state["_prompt"] = ( + "请输入想要订阅的平台,目前支持,请输入冒号左边的名称:\n" + + "".join( + [ + "{}:{}\n".format( + platform_name, platform_manager[platform_name].name + ) + for platform_name in common_platform + ] + ) + + "要查看全部平台请输入:“全部”" + ) - async def parse_platform(bot: AbstractBot, event: AbstractEvent, state: T_State) -> None: + async def parse_platform( + bot: AbstractBot, event: AbstractEvent, state: T_State + ) -> None: platform = str(event.get_message()).strip() - if platform == '全部': - message = '全部平台\n' + \ - '\n'.join(['{}:{}'.format(platform_name, platform.name) \ - for platform_name, platform in platform_manager.items()]) + if platform == "全部": + message = "全部平台\n" + "\n".join( + [ + "{}:{}".format(platform_name, platform.name) + for platform_name, platform in platform_manager.items() + ] + ) await add_sub.reject(message) elif platform in platform_manager: - state['platform'] = platform + state["platform"] = platform else: - await add_sub.reject('平台输入错误') + await add_sub.reject("平台输入错误") - @add_sub.got('platform', _gen_prompt_template('{_prompt}'), parse_platform) + @add_sub.got("platform", _gen_prompt_template("{_prompt}"), parse_platform) @add_sub.handle() async def init_id(bot: Bot, event: Event, state: T_State): - if platform_manager[state['platform']].has_target: - state['_prompt'] = '请输入订阅用户的id,详情查阅https://nonebot-bison.vercel.app/usage/#%E6%89%80%E6%94%AF%E6%8C%81%E5%B9%B3%E5%8F%B0%E7%9A%84uid' + if platform_manager[state["platform"]].has_target: + state[ + "_prompt" + ] = "请输入订阅用户的id,详情查阅https://nonebot-bison.vercel.app/usage/#%E6%89%80%E6%94%AF%E6%8C%81%E5%B9%B3%E5%8F%B0%E7%9A%84uid" else: - state['id'] = 'default' - state['name'] = await platform_manager[state['platform']].get_target_name(Target('')) + state["id"] = "default" + state["name"] = await platform_manager[state["platform"]].get_target_name( + Target("") + ) async def parse_id(bot: AbstractBot, event: AbstractEvent, state: T_State): target = str(event.get_message()).strip() try: - name = await check_sub_target(state['platform'], target) + name = await check_sub_target(state["platform"], target) if not name: - await add_sub.reject('id输入错误') - state['id'] = target - state['name'] = name + await add_sub.reject("id输入错误") + state["id"] = target + state["name"] = name except: - await add_sub.reject('id输入错误') + await add_sub.reject("id输入错误") - @add_sub.got('id', _gen_prompt_template('{_prompt}'), parse_id) + @add_sub.got("id", _gen_prompt_template("{_prompt}"), parse_id) @add_sub.handle() async def init_cat(bot: Bot, event: Event, state: T_State): - if not platform_manager[state['platform']].categories: - state['cats'] = [] + if not platform_manager[state["platform"]].categories: + state["cats"] = [] return - state['_prompt'] = '请输入要订阅的类别,以空格分隔,支持的类别有:{}'.format( - ' '.join(list(platform_manager[state['platform']].categories.values()))) + state["_prompt"] = "请输入要订阅的类别,以空格分隔,支持的类别有:{}".format( + " ".join(list(platform_manager[state["platform"]].categories.values())) + ) async def parser_cats(bot: AbstractBot, event: AbstractEvent, state: T_State): res = [] for cat in str(event.get_message()).strip().split(): - if cat not in platform_manager[state['platform']].reverse_category: - await add_sub.reject('不支持 {}'.format(cat)) - res.append(platform_manager[state['platform']].reverse_category[cat]) - state['cats'] = res + if cat not in platform_manager[state["platform"]].reverse_category: + await add_sub.reject("不支持 {}".format(cat)) + res.append(platform_manager[state["platform"]].reverse_category[cat]) + state["cats"] = res - @add_sub.got('cats', _gen_prompt_template('{_prompt}'), parser_cats) + @add_sub.got("cats", _gen_prompt_template("{_prompt}"), parser_cats) @add_sub.handle() async def init_tag(bot: Bot, event: Event, state: T_State): - if not platform_manager[state['platform']].enable_tag: - state['tags'] = [] + if not platform_manager[state["platform"]].enable_tag: + state["tags"] = [] return - state['_prompt'] = '请输入要订阅的tag,订阅所有tag输入"全部标签"' + state["_prompt"] = '请输入要订阅的tag,订阅所有tag输入"全部标签"' async def parser_tags(bot: AbstractBot, event: AbstractEvent, state: T_State): - if str(event.get_message()).strip() == '全部标签': - state['tags'] = [] + if str(event.get_message()).strip() == "全部标签": + state["tags"] = [] else: - state['tags'] = str(event.get_message()).strip().split() + state["tags"] = str(event.get_message()).strip().split() - @add_sub.got('tags', _gen_prompt_template('{_prompt}'), parser_tags) + @add_sub.got("tags", _gen_prompt_template("{_prompt}"), parser_tags) @add_sub.handle() async def add_sub_process(bot: Bot, event: Event, state: T_State): config = Config() - config.add_subscribe(state.get('_user_id') or event.group_id, user_type='group', - target=state['id'], - target_name=state['name'], target_type=state['platform'], - cats=state.get('cats', []), tags=state.get('tags', [])) - await add_sub.finish('添加 {} 成功'.format(state['name'])) + config.add_subscribe( + state.get("_user_id") or event.group_id, + user_type="group", + target=state["id"], + target_name=state["name"], + target_type=state["platform"], + cats=state.get("cats", []), + tags=state.get("tags", []), + ) + await add_sub.finish("添加 {} 成功".format(state["name"])) + def do_query_sub(query_sub: Type[Matcher]): @query_sub.handle() async def _(bot: Bot, event: Event, state: T_State): config: Config = Config() - sub_list = config.list_subscribe(state.get('_user_id') or event.group_id, "group") - res = '订阅的帐号为:\n' + sub_list = config.list_subscribe( + state.get("_user_id") or event.group_id, "group" + ) + res = "订阅的帐号为:\n" for sub in sub_list: - res += '{} {} {}'.format(sub['target_type'], sub['target_name'], sub['target']) - platform = platform_manager[sub['target_type']] + res += "{} {} {}".format( + sub["target_type"], sub["target_name"], sub["target"] + ) + platform = platform_manager[sub["target_type"]] if platform.categories: - res += ' [{}]'.format(', '.join(map(lambda x: platform.categories[x], sub['cats']))) + res += " [{}]".format( + ", ".join(map(lambda x: platform.categories[x], sub["cats"])) + ) if platform.enable_tag: - res += ' {}'.format(', '.join(sub['tags'])) - res += '\n' + res += " {}".format(", ".join(sub["tags"])) + res += "\n" await query_sub.finish(Message(await parse_text(res))) + def do_del_sub(del_sub: Type[Matcher]): @del_sub.handle() async def send_list(bot: Bot, event: Event, state: T_State): config: Config = Config() - sub_list = config.list_subscribe(state.get('_user_id') or event.group_id, "group") - res = '订阅的帐号为:\n' - state['sub_table'] = {} + sub_list = config.list_subscribe( + state.get("_user_id") or event.group_id, "group" + ) + res = "订阅的帐号为:\n" + state["sub_table"] = {} for index, sub in enumerate(sub_list, 1): - state['sub_table'][index] = {'target_type': sub['target_type'], 'target': sub['target']} - res += '{} {} {} {}\n'.format(index, sub['target_type'], sub['target_name'], sub['target']) - platform = platform_manager[sub['target_type']] + state["sub_table"][index] = { + "target_type": sub["target_type"], + "target": sub["target"], + } + res += "{} {} {} {}\n".format( + index, sub["target_type"], sub["target_name"], sub["target"] + ) + platform = platform_manager[sub["target_type"]] if platform.categories: - res += ' [{}]'.format(', '.join(map(lambda x: platform.categories[x], sub['cats']))) + res += " [{}]".format( + ", ".join(map(lambda x: platform.categories[x], sub["cats"])) + ) if platform.enable_tag: - res += ' {}'.format(', '.join(sub['tags'])) - res += '\n' - res += '请输入要删除的订阅的序号' + res += " {}".format(", ".join(sub["tags"])) + res += "\n" + res += "请输入要删除的订阅的序号" await bot.send(event=event, message=Message(await parse_text(res))) @del_sub.receive() @@ -154,39 +201,60 @@ def do_del_sub(del_sub: Type[Matcher]): try: index = int(str(event.get_message()).strip()) config = Config() - config.del_subscribe(state.get('_user_id') or event.group_id, 'group', **state['sub_table'][index]) + config.del_subscribe( + state.get("_user_id") or event.group_id, + "group", + **state["sub_table"][index] + ) except Exception as e: - await del_sub.reject('删除错误') + await del_sub.reject("删除错误") logger.warning(e) else: - await del_sub.finish('删除成功') + await del_sub.finish("删除成功") + async def parse_group_number(bot: AbstractBot, event: AbstractEvent, state: T_State): state[state["_current_key"]] = int(str(event.get_message())) -add_sub_matcher = on_command("添加订阅", rule=to_me(), permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, priority=5) +add_sub_matcher = on_command( + "添加订阅", rule=to_me(), permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, priority=5 +) do_add_sub(add_sub_matcher) -manage_add_sub_mather = on_command('管理-添加订阅', permission=SUPERUSER, priority=5) -@manage_add_sub_mather.got('_user_id', "群号", parse_group_number) +manage_add_sub_mather = on_command("管理-添加订阅", permission=SUPERUSER, priority=5) + + +@manage_add_sub_mather.got("_user_id", "群号", parse_group_number) async def handle(bot: Bot, event: Event, state: T_State): pass + + do_add_sub(manage_add_sub_mather) query_sub_macher = on_command("查询订阅", rule=to_me(), priority=5) do_query_sub(query_sub_macher) -manage_query_sub_mather = on_command('管理-查询订阅', permission=SUPERUSER, priority=5) -@manage_query_sub_mather.got('_user_id', "群号", parse_group_number) +manage_query_sub_mather = on_command("管理-查询订阅", permission=SUPERUSER, priority=5) + + +@manage_query_sub_mather.got("_user_id", "群号", parse_group_number) async def handle(bot: Bot, event: Event, state: T_State): pass + + do_query_sub(manage_query_sub_mather) -del_sub_macher = on_command("删除订阅", rule=to_me(), permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, priority=5) +del_sub_macher = on_command( + "删除订阅", rule=to_me(), permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, priority=5 +) do_del_sub(del_sub_macher) -manage_del_sub_mather = on_command('管理-删除订阅', permission=SUPERUSER, priority=5) -@manage_del_sub_mather.got('_user_id', "群号", parse_group_number) +manage_del_sub_mather = on_command("管理-删除订阅", permission=SUPERUSER, priority=5) + + +@manage_del_sub_mather.got("_user_id", "群号", parse_group_number) async def handle(bot: Bot, event: Event, state: T_State): pass + + do_del_sub(manage_del_sub_mather) diff --git a/src/plugins/nonebot_bison/platform/__init__.py b/src/plugins/nonebot_bison/platform/__init__.py index 028f002..60b5e32 100644 --- a/src/plugins/nonebot_bison/platform/__init__.py +++ b/src/plugins/nonebot_bison/platform/__init__.py @@ -1,18 +1,19 @@ from collections import defaultdict - -from .platform import Platform, NoTargetGroup -from pkgutil import iter_modules -from pathlib import Path from importlib import import_module +from pathlib import Path +from pkgutil import iter_modules + +from .platform import NoTargetGroup, Platform _package_dir = str(Path(__file__).resolve().parent) for (_, module_name, _) in iter_modules([_package_dir]): - import_module(f'{__name__}.{module_name}') + import_module(f"{__name__}.{module_name}") async def check_sub_target(target_type, target): return await platform_manager[target_type].get_target_name(target) + _platform_list = defaultdict(list) for _platform in Platform.registry: if not _platform.enabled: @@ -24,5 +25,6 @@ for name, platform_list in _platform_list.items(): if len(platform_list) == 1: platform_manager[name] = platform_list[0]() else: - platform_manager[name] = NoTargetGroup([_platform() for _platform in platform_list]) - + platform_manager[name] = NoTargetGroup( + [_platform() for _platform in platform_list] + ) diff --git a/src/plugins/nonebot_bison/platform/arknights.py b/src/plugins/nonebot_bison/platform/arknights.py index fb25729..f7b0002 100644 --- a/src/plugins/nonebot_bison/platform/arknights.py +++ b/src/plugins/nonebot_bison/platform/arknights.py @@ -1,8 +1,8 @@ import json from typing import Any -from bs4 import BeautifulSoup as bs import httpx +from bs4 import BeautifulSoup as bs from ..post import Post from ..types import Category, RawPost, Target @@ -12,26 +12,28 @@ from .platform import CategoryNotSupport, NewMessage, StatusChange class Arknights(NewMessage): - categories = {1: '游戏公告'} - platform_name = 'arknights' - name = '明日方舟游戏信息' + categories = {1: "游戏公告"} + platform_name = "arknights" + name = "明日方舟游戏信息" enable_tag = False enabled = True is_common = False - schedule_type = 'interval' - schedule_kw = {'seconds': 30} + schedule_type = "interval" + schedule_kw = {"seconds": 30} has_target = False async def get_target_name(self, _: Target) -> str: - return '明日方舟游戏信息' + return "明日方舟游戏信息" async def get_sub_list(self, _) -> list[RawPost]: async with httpx.AsyncClient() as client: - raw_data = await client.get('https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/announcement.meta.json') - return json.loads(raw_data.text)['announceList'] + raw_data = await client.get( + "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/announcement.meta.json" + ) + return json.loads(raw_data.text)["announceList"] def get_id(self, post: RawPost) -> Any: - return post['announceId'] + return post["announceId"] def get_date(self, _: RawPost) -> None: return None @@ -40,64 +42,85 @@ class Arknights(NewMessage): return Category(1) async def parse(self, raw_post: RawPost) -> Post: - announce_url = raw_post['webUrl'] - text = '' + announce_url = raw_post["webUrl"] + text = "" async with httpx.AsyncClient() as client: raw_html = await client.get(announce_url) - soup = bs(raw_html, 'html.parser') + soup = bs(raw_html, "html.parser") pics = [] if soup.find("div", class_="standerd-container"): # 图文 render = Render() - viewport = {'width': 320, 'height': 6400, 'deviceScaleFactor': 3} - pic_data = await render.render(announce_url, viewport=viewport, target='div.main') + viewport = {"width": 320, "height": 6400, "deviceScaleFactor": 3} + pic_data = await render.render( + announce_url, viewport=viewport, target="div.main" + ) if pic_data: pics.append(pic_data) else: - text = '图片渲染失败' - elif (pic := soup.find('img', class_='banner-image')): - pics.append(pic['src']) + text = "图片渲染失败" + elif pic := soup.find("img", class_="banner-image"): + pics.append(pic["src"]) else: raise CategoryNotSupport() - return Post('arknights', text=text, url='', target_name="明日方舟游戏内公告", pics=pics, compress=True, override_use_pic=False) + return Post( + "arknights", + text=text, + url="", + target_name="明日方舟游戏内公告", + pics=pics, + compress=True, + override_use_pic=False, + ) + class AkVersion(StatusChange): - categories = {2: '更新信息'} - platform_name = 'arknights' - name = '明日方舟游戏信息' + categories = {2: "更新信息"} + platform_name = "arknights" + name = "明日方舟游戏信息" enable_tag = False enabled = True is_common = False - schedule_type = 'interval' - schedule_kw = {'seconds': 30} + schedule_type = "interval" + schedule_kw = {"seconds": 30} has_target = False async def get_target_name(self, _: Target) -> str: - return '明日方舟游戏信息' + return "明日方舟游戏信息" async def get_status(self, _): async with httpx.AsyncClient() as client: - res_ver = await client.get('https://ak-conf.hypergryph.com/config/prod/official/IOS/version') - res_preanounce = await client.get('https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json') + res_ver = await client.get( + "https://ak-conf.hypergryph.com/config/prod/official/IOS/version" + ) + res_preanounce = await client.get( + "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json" + ) res = res_ver.json() res.update(res_preanounce.json()) return res def compare_status(self, _, old_status, new_status): res = [] - if old_status.get('preAnnounceType') == 2 and new_status.get('preAnnounceType') == 0: - res.append(Post('arknights', - text='登录界面维护公告上线(大概是开始维护了)', - target_name='明日方舟更新信息')) - elif old_status.get('preAnnounceType') == 0 and new_status.get('preAnnounceType') == 2: - res.append(Post('arknights', - text='登录界面维护公告下线(大概是开服了,冲!)', - target_name='明日方舟更新信息')) - if old_status.get('clientVersion') != new_status.get('clientVersion'): - res.append(Post('arknights', text='游戏本体更新(大更新)', target_name='明日方舟更新信息')) - if old_status.get('resVersion') != new_status.get('resVersion'): - res.append(Post('arknights', text='游戏资源更新(小更新)', target_name='明日方舟更新信息')) + if ( + old_status.get("preAnnounceType") == 2 + and new_status.get("preAnnounceType") == 0 + ): + res.append( + Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息") + ) + elif ( + old_status.get("preAnnounceType") == 0 + and new_status.get("preAnnounceType") == 2 + ): + res.append( + Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息") + ) + if old_status.get("clientVersion") != new_status.get("clientVersion"): + res.append(Post("arknights", text="游戏本体更新(大更新)", target_name="明日方舟更新信息")) + if old_status.get("resVersion") != new_status.get("resVersion"): + res.append(Post("arknights", text="游戏资源更新(小更新)", target_name="明日方舟更新信息")) return res def get_category(self, _): @@ -106,28 +129,29 @@ class AkVersion(StatusChange): async def parse(self, raw_post): return raw_post + class MonsterSiren(NewMessage): - categories = {3: '塞壬唱片新闻'} - platform_name = 'arknights' - name = '明日方舟游戏信息' + categories = {3: "塞壬唱片新闻"} + platform_name = "arknights" + name = "明日方舟游戏信息" enable_tag = False enabled = True is_common = False - schedule_type = 'interval' - schedule_kw = {'seconds': 30} + schedule_type = "interval" + schedule_kw = {"seconds": 30} has_target = False async def get_target_name(self, _: Target) -> str: - return '明日方舟游戏信息' + return "明日方舟游戏信息" async def get_sub_list(self, _) -> list[RawPost]: async with httpx.AsyncClient() as client: - raw_data = await client.get('https://monster-siren.hypergryph.com/api/news') - return raw_data.json()['data']['list'] + raw_data = await client.get("https://monster-siren.hypergryph.com/api/news") + return raw_data.json()["data"]["list"] def get_id(self, post: RawPost) -> Any: - return post['cid'] + return post["cid"] def get_date(self, _) -> None: return None @@ -138,13 +162,21 @@ class MonsterSiren(NewMessage): async def parse(self, raw_post: RawPost) -> Post: url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}' async with httpx.AsyncClient() as client: - res = await client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}') + res = await client.get( + f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}' + ) raw_data = res.json() - content = raw_data['data']['content'] - content = content.replace('
', '\n') - soup = bs(content, 'html.parser') - imgs = list(map(lambda x: x['src'], soup('img'))) + content = raw_data["data"]["content"] + content = content.replace("", "\n") + soup = bs(content, "html.parser") + imgs = list(map(lambda x: x["src"], soup("img"))) text = f'{raw_post["title"]}\n{soup.text.strip()}' - return Post('monster-siren', text=text, pics=imgs, - url=url, target_name="塞壬唱片新闻", compress=True, - override_use_pic=False) + return Post( + "monster-siren", + text=text, + pics=imgs, + url=url, + target_name="塞壬唱片新闻", + compress=True, + override_use_pic=False, + ) diff --git a/src/plugins/nonebot_bison/platform/bilibili.py b/src/plugins/nonebot_bison/platform/bilibili.py index 8980eb1..5ff3cc1 100644 --- a/src/plugins/nonebot_bison/platform/bilibili.py +++ b/src/plugins/nonebot_bison/platform/bilibili.py @@ -5,50 +5,57 @@ import httpx from ..post import Post from ..types import Category, RawPost, Tag, Target -from .platform import NewMessage, CategoryNotSupport +from .platform import CategoryNotSupport, NewMessage + class Bilibili(NewMessage): categories = { - 1: "一般动态", - 2: "专栏文章", - 3: "视频", - 4: "纯文字", - 5: "转发" - # 5: "短视频" - } - platform_name = 'bilibili' + 1: "一般动态", + 2: "专栏文章", + 3: "视频", + 4: "纯文字", + 5: "转发" + # 5: "短视频" + } + platform_name = "bilibili" enable_tag = True enabled = True is_common = True - schedule_type = 'interval' - schedule_kw = {'seconds': 10} - name = 'B站' + schedule_type = "interval" + schedule_kw = {"seconds": 10} + name = "B站" has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: - res = await client.get('https://api.bilibili.com/x/space/acc/info', params={'mid': target}) + res = await client.get( + "https://api.bilibili.com/x/space/acc/info", params={"mid": target} + ) res_data = json.loads(res.text) - if res_data['code']: + if res_data["code"]: return None - return res_data['data']['name'] + return res_data["data"]["name"] async def get_sub_list(self, target: Target) -> list[RawPost]: async with httpx.AsyncClient() as client: - params = {'host_uid': target, 'offset': 0, 'need_top': 0} - res = await client.get('https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history', params=params, timeout=4.0) + params = {"host_uid": target, "offset": 0, "need_top": 0} + res = await client.get( + "https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history", + params=params, + timeout=4.0, + ) res_dict = json.loads(res.text) - if res_dict['code'] == 0: - return res_dict['data']['cards'] + if res_dict["code"] == 0: + return res_dict["data"]["cards"] else: return [] def get_id(self, post: RawPost) -> Any: - return post['desc']['dynamic_id'] - + return post["desc"]["dynamic_id"] + def get_date(self, post: RawPost) -> int: - return post['desc']['timestamp'] + return post["desc"]["timestamp"] def _do_get_category(self, post_type: int) -> Category: if post_type == 2: @@ -65,63 +72,75 @@ class Bilibili(NewMessage): raise CategoryNotSupport() def get_category(self, post: RawPost) -> Category: - post_type = post['desc']['type'] + post_type = post["desc"]["type"] return self._do_get_category(post_type) def get_tags(self, raw_post: RawPost) -> list[Tag]: - return [*map(lambda tp: tp['topic_name'], raw_post['display']['topic_info']['topic_details'])] + return [ + *map( + lambda tp: tp["topic_name"], + raw_post["display"]["topic_info"]["topic_details"], + ) + ] def _get_info(self, post_type: Category, card) -> tuple[str, list]: if post_type == 1: # 一般动态 - text = card['item']['description'] - pic = [img['img_src'] for img in card['item']['pictures']] + text = card["item"]["description"] + pic = [img["img_src"] for img in card["item"]["pictures"]] elif post_type == 2: # 专栏文章 - text = '{} {}'.format(card['title'], card['summary']) - pic = card['image_urls'] + text = "{} {}".format(card["title"], card["summary"]) + pic = card["image_urls"] elif post_type == 3: # 视频 - text = card['dynamic'] - pic = [card['pic']] + text = card["dynamic"] + pic = [card["pic"]] elif post_type == 4: # 纯文字 - text = card['item']['content'] + text = card["item"]["content"] pic = [] else: raise CategoryNotSupport() return text, pic async def parse(self, raw_post: RawPost) -> Post: - card_content = json.loads(raw_post['card']) + card_content = json.loads(raw_post["card"]) post_type = self.get_category(raw_post) - target_name = raw_post['desc']['user_profile']['info']['uname'] + target_name = raw_post["desc"]["user_profile"]["info"]["uname"] if post_type >= 1 and post_type < 5: - url = '' + url = "" if post_type == 1: # 一般动态 - url = 'https://t.bilibili.com/{}'.format(raw_post['desc']['dynamic_id_str']) + url = "https://t.bilibili.com/{}".format( + raw_post["desc"]["dynamic_id_str"] + ) elif post_type == 2: # 专栏文章 - url = 'https://www.bilibili.com/read/cv{}'.format(raw_post['desc']['rid']) + url = "https://www.bilibili.com/read/cv{}".format( + raw_post["desc"]["rid"] + ) elif post_type == 3: # 视频 - url = 'https://www.bilibili.com/video/{}'.format(raw_post['desc']['bvid']) + url = "https://www.bilibili.com/video/{}".format( + raw_post["desc"]["bvid"] + ) elif post_type == 4: # 纯文字 - url = 'https://t.bilibili.com/{}'.format(raw_post['desc']['dynamic_id_str']) + url = "https://t.bilibili.com/{}".format( + raw_post["desc"]["dynamic_id_str"] + ) text, pic = self._get_info(post_type, card_content) elif post_type == 5: # 转发 - url = 'https://t.bilibili.com/{}'.format(raw_post['desc']['dynamic_id_str']) - text = card_content['item']['content'] - orig_type = card_content['item']['orig_type'] - orig = json.loads(card_content['origin']) + url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"]) + text = card_content["item"]["content"] + orig_type = card_content["item"]["orig_type"] + orig = json.loads(card_content["origin"]) orig_text, _ = self._get_info(self._do_get_category(orig_type), orig) - text += '\n--------------\n' + text += "\n--------------\n" text += orig_text pic = [] else: raise CategoryNotSupport(post_type) - return Post('bilibili', text=text, url=url, pics=pic, target_name=target_name) - + return Post("bilibili", text=text, url=url, pics=pic, target_name=target_name) diff --git a/src/plugins/nonebot_bison/platform/ncm_artist.py b/src/plugins/nonebot_bison/platform/ncm_artist.py index e230f65..a30072f 100644 --- a/src/plugins/nonebot_bison/platform/ncm_artist.py +++ b/src/plugins/nonebot_bison/platform/ncm_artist.py @@ -1,54 +1,58 @@ from typing import Any, Optional import httpx + from ..post import Post from ..types import RawPost, Target from .platform import NewMessage + class NcmArtist(NewMessage): categories = {} - platform_name = 'ncm-artist' + platform_name = "ncm-artist" enable_tag = False enabled = True is_common = True - schedule_type = 'interval' - schedule_kw = {'minutes': 1} + schedule_type = "interval" + schedule_kw = {"minutes": 1} name = "网易云-歌手" has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: res = await client.get( - "https://music.163.com/api/artist/albums/{}".format(target), - headers={'Referer': 'https://music.163.com/'} - ) + "https://music.163.com/api/artist/albums/{}".format(target), + headers={"Referer": "https://music.163.com/"}, + ) res_data = res.json() - if res_data['code'] != 200: + if res_data["code"] != 200: return - return res_data['artist']['name'] + return res_data["artist"]["name"] async def get_sub_list(self, target: Target) -> list[RawPost]: async with httpx.AsyncClient() as client: res = await client.get( - "https://music.163.com/api/artist/albums/{}".format(target), - headers={'Referer': 'https://music.163.com/'} - ) + "https://music.163.com/api/artist/albums/{}".format(target), + headers={"Referer": "https://music.163.com/"}, + ) res_data = res.json() - if res_data['code'] != 200: + if res_data["code"] != 200: return [] else: - return res_data['hotAlbums'] + return res_data["hotAlbums"] def get_id(self, post: RawPost) -> Any: - return post['id'] + return post["id"] def get_date(self, post: RawPost) -> int: - return post['publishTime'] // 1000 + return post["publishTime"] // 1000 async def parse(self, raw_post: RawPost) -> Post: - text = '新专辑发布:{}'.format(raw_post['name']) - target_name = raw_post['artist']['name'] - pics = [raw_post['picUrl']] - url = "https://music.163.com/#/album?id={}".format(raw_post['id']) - return Post('ncm-artist', text=text, url=url, pics=pics, target_name=target_name) + text = "新专辑发布:{}".format(raw_post["name"]) + target_name = raw_post["artist"]["name"] + pics = [raw_post["picUrl"]] + url = "https://music.163.com/#/album?id={}".format(raw_post["id"]) + return Post( + "ncm-artist", text=text, url=url, pics=pics, target_name=target_name + ) diff --git a/src/plugins/nonebot_bison/platform/ncm_radio.py b/src/plugins/nonebot_bison/platform/ncm_radio.py index 6fae725..20abb52 100644 --- a/src/plugins/nonebot_bison/platform/ncm_radio.py +++ b/src/plugins/nonebot_bison/platform/ncm_radio.py @@ -1,56 +1,58 @@ from typing import Any, Optional import httpx + from ..post import Post from ..types import RawPost, Target from .platform import NewMessage + class NcmRadio(NewMessage): categories = {} - platform_name = 'ncm-radio' + platform_name = "ncm-radio" enable_tag = False enabled = True is_common = False - schedule_type = 'interval' - schedule_kw = {'minutes': 10} + schedule_type = "interval" + schedule_kw = {"minutes": 10} name = "网易云-电台" has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: res = await client.post( - "http://music.163.com/api/dj/program/byradio", - headers={'Referer': 'https://music.163.com/'}, - data={"radioId": target, "limit": 1000, "offset": 0} - ) + "http://music.163.com/api/dj/program/byradio", + headers={"Referer": "https://music.163.com/"}, + data={"radioId": target, "limit": 1000, "offset": 0}, + ) res_data = res.json() - if res_data['code'] != 200 or res_data['programs'] == 0: + if res_data["code"] != 200 or res_data["programs"] == 0: return - return res_data['programs'][0]['radio']['name'] + return res_data["programs"][0]["radio"]["name"] async def get_sub_list(self, target: Target) -> list[RawPost]: async with httpx.AsyncClient() as client: res = await client.post( - "http://music.163.com/api/dj/program/byradio", - headers={'Referer': 'https://music.163.com/'}, - data={"radioId": target, "limit": 1000, "offset": 0} - ) + "http://music.163.com/api/dj/program/byradio", + headers={"Referer": "https://music.163.com/"}, + data={"radioId": target, "limit": 1000, "offset": 0}, + ) res_data = res.json() - if res_data['code'] != 200: + if res_data["code"] != 200: return [] else: - return res_data['programs'] + return res_data["programs"] def get_id(self, post: RawPost) -> Any: - return post['id'] + return post["id"] def get_date(self, post: RawPost) -> int: - return post['createTime'] // 1000 + return post["createTime"] // 1000 async def parse(self, raw_post: RawPost) -> Post: - text = '网易云电台更新:{}'.format(raw_post['name']) - target_name = raw_post['radio']['name'] - pics = [raw_post['coverUrl']] - url = "https://music.163.com/#/program/{}".format(raw_post['id']) - return Post('ncm-radio', text=text, url=url, pics=pics, target_name=target_name) + text = "网易云电台更新:{}".format(raw_post["name"]) + target_name = raw_post["radio"]["name"] + pics = [raw_post["coverUrl"]] + url = "https://music.163.com/#/program/{}".format(raw_post["id"]) + return Post("ncm-radio", text=text, url=url, pics=pics, target_name=target_name) diff --git a/src/plugins/nonebot_bison/platform/platform.py b/src/plugins/nonebot_bison/platform/platform.py index cff3ff6..1bec3d8 100644 --- a/src/plugins/nonebot_bison/platform/platform.py +++ b/src/plugins/nonebot_bison/platform/platform.py @@ -1,8 +1,8 @@ -from abc import abstractmethod, ABC +import time +from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass -import time -from typing import Any, Collection, Optional, Literal +from typing import Any, Collection, Literal, Optional import httpx from nonebot import logger @@ -17,26 +17,27 @@ class CategoryNotSupport(Exception): class RegistryMeta(type): - def __new__(cls, name, bases, namespace, **kwargs): return super().__new__(cls, name, bases, namespace) def __init__(cls, name, bases, namespace, **kwargs): - if kwargs.get('base'): + if kwargs.get("base"): # this is the base class cls.registry = [] - elif not kwargs.get('abstract'): + elif not kwargs.get("abstract"): # this is the subclass cls.registry.append(cls) super().__init__(name, bases, namespace, **kwargs) + class RegistryABCMeta(RegistryMeta, ABC): ... + class Platform(metaclass=RegistryABCMeta, base=True): - - schedule_type: Literal['date', 'interval', 'cron'] + + schedule_type: Literal["date", "interval", "cron"] schedule_kw: dict is_common: bool enabled: bool @@ -52,7 +53,9 @@ class Platform(metaclass=RegistryABCMeta, base=True): ... @abstractmethod - async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: + async def fetch_new_post( + self, target: Target, users: list[UserSubInfo] + ) -> list[tuple[User, list[Post]]]: ... @abstractmethod @@ -67,7 +70,7 @@ class Platform(metaclass=RegistryABCMeta, base=True): super().__init__() self.reverse_category = {} for key, val in self.categories.items(): - self.reverse_category[val] = key + self.reverse_category[val] = key self.store = dict() @abstractmethod @@ -75,12 +78,14 @@ class Platform(metaclass=RegistryABCMeta, base=True): "Return Tag list of given RawPost" def get_stored_data(self, target: Target) -> Any: - return self.store.get(target) + return self.store.get(target) def set_stored_data(self, target: Target, data: Any): self.store[target] = data - async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]: + async def filter_user_custom( + self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag] + ) -> list[RawPost]: res: list[RawPost] = [] for raw_post in raw_post_list: if self.categories: @@ -99,12 +104,16 @@ class Platform(metaclass=RegistryABCMeta, base=True): res.append(raw_post) return res - async def dispatch_user_post(self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: + async def dispatch_user_post( + self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo] + ) -> list[tuple[User, list[Post]]]: res: list[tuple[User, list[Post]]] = [] for user, category_getter, tag_getter in users: required_tags = tag_getter(target) if self.enable_tag else [] cats = category_getter(target) - user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags) + user_raw_post = await self.filter_user_custom( + new_posts, cats, required_tags + ) user_post: list[Post] = [] for raw_post in user_raw_post: user_post.append(await self.do_parse(raw_post)) @@ -116,6 +125,7 @@ class Platform(metaclass=RegistryABCMeta, base=True): "Return category of given Rawpost" raise NotImplementedError() + class MessageProcess(Platform, abstract=True): "General message process fetch, parse, filter progress" @@ -127,7 +137,6 @@ class MessageProcess(Platform, abstract=True): def get_id(self, post: RawPost) -> Any: "Get post id of given RawPost" - async def do_parse(self, raw_post: RawPost) -> Post: post_id = self.get_id(raw_post) if post_id not in self.parse_cache: @@ -156,8 +165,11 @@ class MessageProcess(Platform, abstract=True): # post_id = self.get_id(raw_post) # if post_id in exists_posts_set: # continue - if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60 and \ - plugin_config.bison_init_filter: + if ( + (post_time := self.get_date(raw_post)) + and time.time() - post_time > 2 * 60 * 60 + and plugin_config.bison_init_filter + ): continue try: self.get_category(raw_post) @@ -168,15 +180,18 @@ class MessageProcess(Platform, abstract=True): res.append(raw_post) return res + class NewMessage(MessageProcess, abstract=True): "Fetch a list of messages, filter the new messages, dispatch it to different users" @dataclass - class MessageStorage(): + class MessageStorage: inited: bool exists_posts: set[Any] - async def filter_common_with_diff(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]: + async def filter_common_with_diff( + self, target: Target, raw_post_list: list[RawPost] + ) -> list[RawPost]: filtered_post = await self.filter_common(raw_post_list) store = self.get_stored_data(target) or self.MessageStorage(False, set()) res = [] @@ -185,7 +200,11 @@ class NewMessage(MessageProcess, abstract=True): for raw_post in filtered_post: post_id = self.get_id(raw_post) store.exists_posts.add(post_id) - logger.info('init {}-{} with {}'.format(self.platform_name, target, store.exists_posts)) + logger.info( + "init {}-{} with {}".format( + self.platform_name, target, store.exists_posts + ) + ) store.inited = True else: for raw_post in filtered_post: @@ -197,8 +216,9 @@ class NewMessage(MessageProcess, abstract=True): self.set_stored_data(target, store) return res - - async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: + async def fetch_new_post( + self, target: Target, users: list[UserSubInfo] + ) -> list[tuple[User, list[Post]]]: try: post_list = await self.get_sub_list(target) new_posts = await self.filter_common_with_diff(target, post_list) @@ -206,17 +226,25 @@ class NewMessage(MessageProcess, abstract=True): return [] else: for post in new_posts: - logger.info('fetch new post from {} {}: {}'.format( - self.platform_name, - target if self.has_target else '-', - self.get_id(post))) + logger.info( + "fetch new post from {} {}: {}".format( + self.platform_name, + target if self.has_target else "-", + self.get_id(post), + ) + ) res = await self.dispatch_user_post(target, new_posts, users) self.parse_cache = {} return res except httpx.RequestError as err: - logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) + logger.warning( + "network connection error: {}, url: {}".format( + type(err), err.request.url + ) + ) return [] + class StatusChange(Platform, abstract=True): "Watch a status, and fire a post when status changes" @@ -232,49 +260,69 @@ class StatusChange(Platform, abstract=True): async def parse(self, raw_post: RawPost) -> Post: ... - async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: + async def fetch_new_post( + self, target: Target, users: list[UserSubInfo] + ) -> list[tuple[User, list[Post]]]: try: new_status = await self.get_status(target) res = [] if old_status := self.get_stored_data(target): diff = self.compare_status(target, old_status, new_status) if diff: - logger.info("status changes {} {}: {} -> {}".format( - self.platform_name, - target if self.has_target else '-', - old_status, new_status - )) + logger.info( + "status changes {} {}: {} -> {}".format( + self.platform_name, + target if self.has_target else "-", + old_status, + new_status, + ) + ) res = await self.dispatch_user_post(target, diff, users) self.set_stored_data(target, new_status) return res except httpx.RequestError as err: - logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) + logger.warning( + "network connection error: {}, url: {}".format( + type(err), err.request.url + ) + ) return [] + class SimplePost(MessageProcess, abstract=True): "Fetch a list of messages, dispatch it to different users" - async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: + async def fetch_new_post( + self, target: Target, users: list[UserSubInfo] + ) -> list[tuple[User, list[Post]]]: try: new_posts = await self.get_sub_list(target) if not new_posts: return [] else: for post in new_posts: - logger.info('fetch new post from {} {}: {}'.format( - self.platform_name, - target if self.has_target else '-', - self.get_id(post))) + logger.info( + "fetch new post from {} {}: {}".format( + self.platform_name, + target if self.has_target else "-", + self.get_id(post), + ) + ) res = await self.dispatch_user_post(target, new_posts, users) self.parse_cache = {} return res except httpx.RequestError as err: - logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) + logger.warning( + "network connection error: {}, url: {}".format( + type(err), err.request.url + ) + ) return [] + class NoTargetGroup(Platform, abstract=True): enable_tag = False - DUMMY_STR = '_DUMMY' + DUMMY_STR = "_DUMMY" enabled = True has_target = False @@ -287,24 +335,35 @@ class NoTargetGroup(Platform, abstract=True): self.schedule_kw = platform_list[0].schedule_kw for platform in platform_list: if platform.has_target: - raise RuntimeError('Platform {} should have no target'.format(platform.name)) + raise RuntimeError( + "Platform {} should have no target".format(platform.name) + ) if name == self.DUMMY_STR: name = platform.name elif name != platform.name: - raise RuntimeError('Platform name for {} not fit'.format(self.platform_name)) + raise RuntimeError( + "Platform name for {} not fit".format(self.platform_name) + ) platform_category_key_set = set(platform.categories.keys()) if platform_category_key_set & categories_keys: - raise RuntimeError('Platform categories for {} duplicate'.format(self.platform_name)) + raise RuntimeError( + "Platform categories for {} duplicate".format(self.platform_name) + ) categories_keys |= platform_category_key_set self.categories.update(platform.categories) - if platform.schedule_kw != self.schedule_kw or platform.schedule_type != self.schedule_type: - raise RuntimeError('Platform scheduler for {} not fit'.format(self.platform_name)) + if ( + platform.schedule_kw != self.schedule_kw + or platform.schedule_type != self.schedule_type + ): + raise RuntimeError( + "Platform scheduler for {} not fit".format(self.platform_name) + ) self.name = name self.is_common = platform_list[0].is_common super().__init__() def __str__(self): - return '[' + ' '.join(map(lambda x: x.name, self.platform_list)) + ']' + return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]" async def get_target_name(self, _): return await self.platform_list[0].get_target_name(_) @@ -316,4 +375,3 @@ class NoTargetGroup(Platform, abstract=True): for user, posts in platform_res: res[user].extend(posts) return [[key, val] for key, val in res.items()] - diff --git a/src/plugins/nonebot_bison/platform/rss.py b/src/plugins/nonebot_bison/platform/rss.py index 4cc18cc..330d93d 100644 --- a/src/plugins/nonebot_bison/platform/rss.py +++ b/src/plugins/nonebot_bison/platform/rss.py @@ -1,31 +1,32 @@ import calendar from typing import Any, Optional -from bs4 import BeautifulSoup as bs import feedparser import httpx +from bs4 import BeautifulSoup as bs from ..post import Post from ..types import RawPost, Target from .platform import NewMessage + class Rss(NewMessage): categories = {} enable_tag = False - platform_name = 'rss' + platform_name = "rss" name = "Rss" enabled = True is_common = True - schedule_type = 'interval' - schedule_kw = {'seconds': 30} + schedule_type = "interval" + schedule_kw = {"seconds": 30} has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: res = await client.get(target, timeout=10.0) feed = feedparser.parse(res.text) - return feed['feed']['title'] + return feed["feed"]["title"] def get_date(self, post: RawPost) -> int: return calendar.timegm(post.published_parsed) @@ -39,12 +40,18 @@ class Rss(NewMessage): feed = feedparser.parse(res) entries = feed.entries for entry in entries: - entry['_target_name'] = feed.feed.title + entry["_target_name"] = feed.feed.title return feed.entries async def parse(self, raw_post: RawPost) -> Post: - text = raw_post.get('title', '') + '\n' if raw_post.get('title') else '' - soup = bs(raw_post.description, 'html.parser') + text = raw_post.get("title", "") + "\n" if raw_post.get("title") else "" + soup = bs(raw_post.description, "html.parser") text += soup.text.strip() - pics = list(map(lambda x: x.attrs['src'], soup('img'))) - return Post('rss', text=text, url=raw_post.link, pics=pics, target_name=raw_post['_target_name']) + pics = list(map(lambda x: x.attrs["src"], soup("img"))) + return Post( + "rss", + text=text, + url=raw_post.link, + pics=pics, + target_name=raw_post["_target_name"], + ) diff --git a/src/plugins/nonebot_bison/platform/wechat.py b/src/plugins/nonebot_bison/platform/wechat.py index 7c04306..d5f5487 100644 --- a/src/plugins/nonebot_bison/platform/wechat.py +++ b/src/plugins/nonebot_bison/platform/wechat.py @@ -1,14 +1,15 @@ -from datetime import datetime import hashlib import json import re +from datetime import datetime from typing import Any, Optional -from bs4 import BeautifulSoup as bs import httpx +from bs4 import BeautifulSoup as bs from ..post import Post from ..types import * + # from .platform import Platform @@ -75,4 +76,3 @@ from ..types import * # pics=[], # url='' # ) - diff --git a/src/plugins/nonebot_bison/platform/weibo.py b/src/plugins/nonebot_bison/platform/weibo.py index 19d8703..365e3b2 100644 --- a/src/plugins/nonebot_bison/platform/weibo.py +++ b/src/plugins/nonebot_bison/platform/weibo.py @@ -1,121 +1,152 @@ -from datetime import datetime import json import re +from datetime import datetime from typing import Any, Optional -from bs4 import BeautifulSoup as bs import httpx +from bs4 import BeautifulSoup as bs from nonebot import logger from ..post import Post from ..types import * from .platform import NewMessage + class Weibo(NewMessage): categories = { - 1: '转发', - 2: '视频', - 3: '图文', - 4: '文字', - } + 1: "转发", + 2: "视频", + 3: "图文", + 4: "文字", + } enable_tag = True - platform_name = 'weibo' - name = '新浪微博' + platform_name = "weibo" + name = "新浪微博" enabled = True is_common = True - schedule_type = 'interval' - schedule_kw = {'seconds': 3} + schedule_type = "interval" + schedule_kw = {"seconds": 3} has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: - param = {'containerid': '100505' + target} - res = await client.get('https://m.weibo.cn/api/container/getIndex', params=param) + param = {"containerid": "100505" + target} + res = await client.get( + "https://m.weibo.cn/api/container/getIndex", params=param + ) res_dict = json.loads(res.text) - if res_dict.get('ok') == 1: - return res_dict['data']['userInfo']['screen_name'] + if res_dict.get("ok") == 1: + return res_dict["data"]["userInfo"]["screen_name"] else: return None async def get_sub_list(self, target: Target) -> list[RawPost]: async with httpx.AsyncClient() as client: - params = { 'containerid': '107603' + target} - res = await client.get('https://m.weibo.cn/api/container/getIndex?', params=params, timeout=4.0) + params = {"containerid": "107603" + target} + res = await client.get( + "https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0 + ) res_data = json.loads(res.text) - if not res_data['ok']: + if not res_data["ok"]: return [] - custom_filter: Callable[[RawPost], bool] = lambda d: d['card_type'] == 9 - return list(filter(custom_filter, res_data['data']['cards'])) + custom_filter: Callable[[RawPost], bool] = lambda d: d["card_type"] == 9 + return list(filter(custom_filter, res_data["data"]["cards"])) def get_id(self, post: RawPost) -> Any: - return post['mblog']['id'] + return post["mblog"]["id"] def filter_platform_custom(self, raw_post: RawPost) -> bool: - return raw_post['card_type'] == 9 + return raw_post["card_type"] == 9 def get_date(self, raw_post: RawPost) -> float: - created_time = datetime.strptime(raw_post['mblog']['created_at'], '%a %b %d %H:%M:%S %z %Y') + created_time = datetime.strptime( + raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y" + ) return created_time.timestamp() def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]: "Return Tag list of given RawPost" - text = raw_post['mblog']['text'] - soup = bs(text, 'html.parser') - res = list(map( - lambda x: x[1:-1], - filter( - lambda s: s[0] == '#' and s[-1] == '#', - map(lambda x:x.text, soup.find_all('span', class_='surl-text')) - ) - )) - super_topic_img = soup.find('img', src=re.compile(r'timeline_card_small_super_default')) + text = raw_post["mblog"]["text"] + soup = bs(text, "html.parser") + res = list( + map( + lambda x: x[1:-1], + filter( + lambda s: s[0] == "#" and s[-1] == "#", + map(lambda x: x.text, soup.find_all("span", class_="surl-text")), + ), + ) + ) + super_topic_img = soup.find( + "img", src=re.compile(r"timeline_card_small_super_default") + ) if super_topic_img: try: - res.append(super_topic_img.parent.parent.find('span', class_='surl-text').text + '超话') + res.append( + super_topic_img.parent.parent.find("span", class_="surl-text").text + + "超话" + ) except: - logger.info('super_topic extract error: {}'.format(text)) + logger.info("super_topic extract error: {}".format(text)) return res def get_category(self, raw_post: RawPost) -> Category: - if raw_post['mblog'].get('retweeted_status'): + if raw_post["mblog"].get("retweeted_status"): return Category(1) - elif raw_post['mblog'].get('page_info') and raw_post['mblog']['page_info'].get('type') == 'video': + elif ( + raw_post["mblog"].get("page_info") + and raw_post["mblog"]["page_info"].get("type") == "video" + ): return Category(2) - elif raw_post['mblog'].get('pics'): + elif raw_post["mblog"].get("pics"): return Category(3) else: return Category(4) def _get_text(self, raw_text: str) -> str: - text = raw_text.replace('