From 709a3e214b92168a122edd1737c498828ab7979d Mon Sep 17 00:00:00 2001 From: felinae98 <731499577@qq.com> Date: Sat, 12 Feb 2022 10:20:02 +0800 Subject: [PATCH] format code --- .pre-commit-config.yaml | 1 + bot.py | 6 +- src/plugins/auto_agree.py | 24 +- src/plugins/nonebot_bison/__init__.py | 27 +- .../nonebot_bison/admin_page/__init__.py | 95 ++-- src/plugins/nonebot_bison/admin_page/api.py | 155 ++++--- src/plugins/nonebot_bison/admin_page/jwt.py | 19 +- .../nonebot_bison/admin_page/token_manager.py | 16 +- src/plugins/nonebot_bison/config.py | 157 +++++-- src/plugins/nonebot_bison/config_manager.py | 226 +++++---- .../nonebot_bison/platform/__init__.py | 16 +- .../nonebot_bison/platform/arknights.py | 144 +++--- .../nonebot_bison/platform/bilibili.py | 109 +++-- .../nonebot_bison/platform/ncm_artist.py | 44 +- .../nonebot_bison/platform/ncm_radio.py | 46 +- .../nonebot_bison/platform/platform.py | 154 ++++-- src/plugins/nonebot_bison/platform/rss.py | 27 +- src/plugins/nonebot_bison/platform/wechat.py | 6 +- src/plugins/nonebot_bison/platform/weibo.py | 153 +++--- src/plugins/nonebot_bison/plugin_config.py | 14 +- src/plugins/nonebot_bison/post.py | 69 +-- src/plugins/nonebot_bison/scheduler.py | 71 ++- src/plugins/nonebot_bison/send.py | 19 +- src/plugins/nonebot_bison/types.py | 12 +- tests/conftest.py | 29 +- tests/platforms/test_arknights.py | 93 ++-- tests/platforms/test_bilibili.py | 46 +- tests/platforms/test_ncm_artist.py | 53 ++- tests/platforms/test_ncm_radio.py | 58 +-- tests/platforms/test_platform.py | 437 ++++++++++-------- tests/platforms/test_weibo.py | 133 +++--- tests/platforms/utils.py | 16 +- tests/test_config_manager.py | 59 ++- tests/test_merge_pic.py | 96 ++-- tests/test_render.py | 15 +- 35 files changed, 1613 insertions(+), 1032 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b577ac8..77e9bc6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,6 +9,7 @@ repos: rev: 5.10.1 hooks: - id: isort + args: ["--profile", "black", "--filter-files"] - repo: https://github.com/psf/black rev: 22.1.0 diff --git a/bot.py b/bot.py index 6819d73..957a2e5 100644 --- a/bot.py +++ b/bot.py @@ -5,11 +5,11 @@ nonebot.init(command_start=[""]) app = nonebot.get_asgi() driver = nonebot.get_driver() -driver.register_adapter('cqhttp', CQHTTPBot) +driver.register_adapter("cqhttp", CQHTTPBot) nonebot.load_builtin_plugins() -nonebot.load_plugin('nonebot_plugin_help') -nonebot.load_plugins('src/plugins') +nonebot.load_plugin("nonebot_plugin_help") +nonebot.load_plugins("src/plugins") if __name__ == "__main__": nonebot.run(app="bot:app") diff --git a/src/plugins/auto_agree.py b/src/plugins/auto_agree.py index d4b4998..463e54a 100644 --- a/src/plugins/auto_agree.py +++ b/src/plugins/auto_agree.py @@ -1,18 +1,26 @@ -from nonebot import on_request, logger +from nonebot import logger, on_request from nonebot.adapters.cqhttp import Bot, Event +from nonebot.adapters.cqhttp.event import ( + FriendRequestEvent, + GroupRequestEvent, + RequestEvent, +) +from nonebot.adapters.cqhttp.permission import PRIVATE_FRIEND from nonebot.permission import SUPERUSER from nonebot.typing import T_State -from nonebot.adapters.cqhttp.permission import PRIVATE_FRIEND -from nonebot.adapters.cqhttp.event import FriendRequestEvent, GroupRequestEvent, RequestEvent friend_req = on_request(priority=5) + @friend_req.handle() async def add_superuser(bot: Bot, event: RequestEvent, state: T_State): - if str(event.user_id) in bot.config.superusers and event.request_type == 'private': + if str(event.user_id) in bot.config.superusers and event.request_type == "private": await event.approve(bot) - logger.info('add user {}'.format(event.user_id)) - elif event.sub_type == 'invite' and str(event.user_id) in bot.config.superusers and event.request_type == 'group': + logger.info("add user {}".format(event.user_id)) + elif ( + event.sub_type == "invite" + and str(event.user_id) in bot.config.superusers + and event.request_type == "group" + ): await event.approve(bot) - logger.info('add group {}'.format(event.group_id)) - + logger.info("add group {}".format(event.group_id)) diff --git a/src/plugins/nonebot_bison/__init__.py b/src/plugins/nonebot_bison/__init__.py index 73abbeb..353a8a3 100644 --- a/src/plugins/nonebot_bison/__init__.py +++ b/src/plugins/nonebot_bison/__init__.py @@ -1,16 +1,17 @@ import nonebot -from . import config_manager -from . import config -from . import scheduler -from . import send -from . import post -from . import platform -from . import types -from . import utils -from . import admin_page +from . import ( + admin_page, + config, + config_manager, + platform, + post, + scheduler, + send, + types, + utils, +) -__help__version__ = '0.4.3' -__help__plugin__name__ = 'nonebot_bison' -__usage__ = ('本bot可以提供b站、微博等社交媒体的消息订阅,详情' - '请查看本bot文档,或者at本bot发送“添加订阅”订阅第一个帐号') +__help__version__ = "0.4.3" +__help__plugin__name__ = "nonebot_bison" +__usage__ = "本bot可以提供b站、微博等社交媒体的消息订阅,详情" "请查看本bot文档,或者at本bot发送“添加订阅”订阅第一个帐号" diff --git a/src/plugins/nonebot_bison/admin_page/__init__.py b/src/plugins/nonebot_bison/admin_page/__init__.py index 745e0ce..0aba931 100644 --- a/src/plugins/nonebot_bison/admin_page/__init__.py +++ b/src/plugins/nonebot_bison/admin_page/__init__.py @@ -1,8 +1,9 @@ -from dataclasses import dataclass import os +from dataclasses import dataclass from pathlib import Path from typing import Union +import socketio from fastapi.staticfiles import StaticFiles from nonebot import get_driver, on_command from nonebot.adapters.cqhttp.bot import Bot @@ -11,7 +12,6 @@ from nonebot.drivers.fastapi import Driver from nonebot.log import logger from nonebot.rule import to_me from nonebot.typing import T_State -import socketio from ..plugin_config import plugin_config from .api import ( @@ -27,21 +27,21 @@ from .api import ( from .jwt import load_jwt from .token_manager import token_manager as tm -URL_BASE = '/bison/' -GLOBAL_CONF_URL = f'{URL_BASE}api/global_conf' -AUTH_URL = f'{URL_BASE}api/auth' -SUBSCRIBE_URL = f'{URL_BASE}api/subs' -GET_TARGET_NAME_URL = f'{URL_BASE}api/target_name' -TEST_URL = f'{URL_BASE}test' +URL_BASE = "/bison/" +GLOBAL_CONF_URL = f"{URL_BASE}api/global_conf" +AUTH_URL = f"{URL_BASE}api/auth" +SUBSCRIBE_URL = f"{URL_BASE}api/subs" +GET_TARGET_NAME_URL = f"{URL_BASE}api/target_name" +TEST_URL = f"{URL_BASE}test" STATIC_PATH = (Path(__file__).parent / "dist").resolve() sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") socket_app = socketio.ASGIApp(sio, socketio_path="socket") -class SinglePageApplication(StaticFiles): - def __init__(self, directory: os.PathLike, index='index.html'): +class SinglePageApplication(StaticFiles): + def __init__(self, directory: os.PathLike, index="index.html"): self.index = index super().__init__(directory=directory, packages=None, html=True, check_dir=True) @@ -51,12 +51,13 @@ class SinglePageApplication(StaticFiles): return await super().lookup_path(self.index) return (full_path, stat_res) -def register_router_fastapi(driver: Driver, socketio): - from fastapi.security import OAuth2PasswordBearer - from fastapi.param_functions import Depends - from fastapi import HTTPException, status - oath_scheme = OAuth2PasswordBearer(tokenUrl='token') +def register_router_fastapi(driver: Driver, socketio): + from fastapi import HTTPException, status + from fastapi.param_functions import Depends + from fastapi.security import OAuth2PasswordBearer + + oath_scheme = OAuth2PasswordBearer(tokenUrl="token") async def get_jwt_obj(token: str = Depends(oath_scheme)): obj = load_jwt(token) @@ -64,10 +65,12 @@ def register_router_fastapi(driver: Driver, socketio): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) return obj - async def check_group_permission(groupNumber: str, token_obj: dict = Depends(get_jwt_obj)): - groups = token_obj['groups'] + async def check_group_permission( + groupNumber: str, token_obj: dict = Depends(get_jwt_obj) + ): + groups = token_obj["groups"] for group in groups: - if int(groupNumber) == group['id']: + if int(groupNumber) == group["id"]: return raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) @@ -88,17 +91,35 @@ def register_router_fastapi(driver: Driver, socketio): @app.get(SUBSCRIBE_URL) async def subs(jwt_obj: dict = Depends(get_jwt_obj)): return await get_subs_info(jwt_obj) + @app.get(GET_TARGET_NAME_URL) - async def _get_target_name(platformName: str, target: str, jwt_obj: dict = Depends(get_jwt_obj)): + async def _get_target_name( + platformName: str, target: str, jwt_obj: dict = Depends(get_jwt_obj) + ): return await get_target_name(platformName, target, jwt_obj) + @app.post(SUBSCRIBE_URL, dependencies=[Depends(check_group_permission)]) async def _add_group_subs(groupNumber: str, req: AddSubscribeReq): - return await add_group_sub(group_number=groupNumber, platform_name=req.platformName, - target=req.target, target_name=req.targetName, cats=req.cats, tags=req.tags) + return await add_group_sub( + group_number=groupNumber, + platform_name=req.platformName, + target=req.target, + target_name=req.targetName, + cats=req.cats, + tags=req.tags, + ) + @app.patch(SUBSCRIBE_URL, dependencies=[Depends(check_group_permission)]) async def _update_group_subs(groupNumber: str, req: AddSubscribeReq): - return await update_group_sub(group_number=groupNumber, platform_name=req.platformName, - target=req.target, target_name=req.targetName, cats=req.cats, tags=req.tags) + return await update_group_sub( + group_number=groupNumber, + platform_name=req.platformName, + target=req.target, + target_name=req.targetName, + cats=req.cats, + tags=req.tags, + ) + @app.delete(SUBSCRIBE_URL, dependencies=[Depends(check_group_permission)]) async def _del_group_subs(groupNumber: str, target: str, platformName: str): return await del_group_sub(groupNumber, platformName, target) @@ -108,8 +129,8 @@ def register_router_fastapi(driver: Driver, socketio): def init(): driver = get_driver() - if driver.type == 'fastapi': - assert(isinstance(driver, Driver)) + if driver.type == "fastapi": + assert isinstance(driver, Driver) register_router_fastapi(driver, socket_app) else: logger.warning(f"Driver {driver.type} not supported") @@ -118,19 +139,25 @@ def init(): port = driver.config.port if host in ["0.0.0.0", "127.0.0.1"]: host = "localhost" - logger.opt(colors=True).info(f"Nonebot test frontend will be running at: " - f"http://{host}:{port}{URL_BASE}") + logger.opt(colors=True).info( + f"Nonebot test frontend will be running at: " + f"http://{host}:{port}{URL_BASE}" + ) -if (STATIC_PATH / 'index.html').exists(): + +if (STATIC_PATH / "index.html").exists(): init() - get_token = on_command('后台管理', rule=to_me(), priority=5) + get_token = on_command("后台管理", rule=to_me(), priority=5) + @get_token.handle() async def send_token(bot: "Bot", event: PrivateMessageEvent, state: T_State): token = tm.get_user_token((event.get_user_id(), event.sender.nickname)) - await get_token.finish(f'请访问: {plugin_config.bison_outer_url}auth/{token}') - get_token.__help__name__ = '获取后台管理地址' - get_token.__help__info__ = ('获取管理bot后台的地址,该地址会' - '在一段时间过后过期,请不要泄漏该地址') + await get_token.finish(f"请访问: {plugin_config.bison_outer_url}auth/{token}") + + get_token.__help__name__ = "获取后台管理地址" + get_token.__help__info__ = "获取管理bot后台的地址,该地址会" "在一段时间过后过期,请不要泄漏该地址" else: - logger.warning("Frontend file not found, please compile it or use docker or pypi version") + logger.warning( + "Frontend file not found, please compile it or use docker or pypi version" + ) diff --git a/src/plugins/nonebot_bison/admin_page/api.py b/src/plugins/nonebot_bison/admin_page/api.py index ff17d62..b2f519f 100644 --- a/src/plugins/nonebot_bison/admin_page/api.py +++ b/src/plugins/nonebot_bison/admin_page/api.py @@ -6,110 +6,141 @@ from ..platform import check_sub_target, platform_manager from .jwt import pack_jwt from .token_manager import token_manager + async def test(): return {"status": 200, "text": "test"} + async def get_global_conf(): res = {} for platform_name, platform in platform_manager.items(): res[platform_name] = { - 'platformName': platform_name, - 'categories': platform.categories, - 'enabledTag': platform.enable_tag, - 'name': platform.name, - 'hasTarget': getattr(platform, 'has_target') - } - return { 'platformConf': res } + "platformName": platform_name, + "categories": platform.categories, + "enabledTag": platform.enable_tag, + "name": platform.name, + "hasTarget": getattr(platform, "has_target"), + } + return {"platformConf": res} + async def get_admin_groups(qq: int): bot = nonebot.get_bot() - groups = await bot.call_api('get_group_list') + groups = await bot.call_api("get_group_list") res = [] for group in groups: - group_id = group['group_id'] - users = await bot.call_api('get_group_member_list', group_id=group_id) + group_id = group["group_id"] + users = await bot.call_api("get_group_member_list", group_id=group_id) for user in users: - if user['user_id'] == qq and user['role'] in ('owner', 'admin'): - res.append({'id': group_id, 'name': group['group_name']}) + if user["user_id"] == qq and user["role"] in ("owner", "admin"): + res.append({"id": group_id, "name": group["group_name"]}) return res + async def auth(token: str): if qq_tuple := token_manager.get_user(token): qq, nickname = qq_tuple bot = nonebot.get_bot() - assert(isinstance(bot, Bot)) - groups = await bot.call_api('get_group_list') + assert isinstance(bot, Bot) + groups = await bot.call_api("get_group_list") if str(qq) in nonebot.get_driver().config.superusers: jwt_obj = { - 'id': str(qq), - 'groups': list(map( - lambda info: {'id': info['group_id'], 'name': info['group_name']}, - groups)), - } + "id": str(qq), + "groups": list( + map( + lambda info: { + "id": info["group_id"], + "name": info["group_name"], + }, + groups, + ) + ), + } ret_obj = { - 'type': 'admin', - 'name': nickname, - 'id': str(qq), - 'token': pack_jwt(jwt_obj) - } - return { 'status': 200, **ret_obj } + "type": "admin", + "name": nickname, + "id": str(qq), + "token": pack_jwt(jwt_obj), + } + return {"status": 200, **ret_obj} if admin_groups := await get_admin_groups(int(qq)): - jwt_obj = { - 'id': str(qq), - 'groups': admin_groups - } + jwt_obj = {"id": str(qq), "groups": admin_groups} ret_obj = { - 'type': 'user', - 'name': nickname, - 'id': str(qq), - 'token': pack_jwt(jwt_obj) - } - return { 'status': 200, **ret_obj } + "type": "user", + "name": nickname, + "id": str(qq), + "token": pack_jwt(jwt_obj), + } + return {"status": 200, **ret_obj} else: - return { 'status': 400, 'type': '', 'name': '', 'id': '', 'token': '' } + return {"status": 400, "type": "", "name": "", "id": "", "token": ""} else: - return { 'status': 400, 'type': '', 'name': '', 'id': '', 'token': '' } + return {"status": 400, "type": "", "name": "", "id": "", "token": ""} + async def get_subs_info(jwt_obj: dict): - groups = jwt_obj['groups'] + groups = jwt_obj["groups"] res = {} for group in groups: - group_id = group['id'] + group_id = group["id"] config = Config() - subs = list(map(lambda sub: { - 'platformName': sub['target_type'], 'target': sub['target'], 'targetName': sub['target_name'], 'cats': sub['cats'], 'tags': sub['tags'] - }, config.list_subscribe(group_id, 'group'))) - res[group_id] = { - 'name': group['name'], - 'subscribes': subs - } + subs = list( + map( + lambda sub: { + "platformName": sub["target_type"], + "target": sub["target"], + "targetName": sub["target_name"], + "cats": sub["cats"], + "tags": sub["tags"], + }, + config.list_subscribe(group_id, "group"), + ) + ) + res[group_id] = {"name": group["name"], "subscribes": subs} return res -async def get_target_name(platform_name: str, target: str, jwt_obj: dict): - return {'targetName': await check_sub_target(platform_name, target)} -async def add_group_sub(group_number: str, platform_name: str, target: str, - target_name: str, cats: list[int], tags: list[str]): +async def get_target_name(platform_name: str, target: str, jwt_obj: dict): + return {"targetName": await check_sub_target(platform_name, target)} + + +async def add_group_sub( + group_number: str, + platform_name: str, + target: str, + target_name: str, + cats: list[int], + tags: list[str], +): config = Config() - config.add_subscribe(int(group_number), 'group', target, target_name, platform_name, cats, tags) - return { 'status': 200, 'msg': '' } + config.add_subscribe( + int(group_number), "group", target, target_name, platform_name, cats, tags + ) + return {"status": 200, "msg": ""} + async def del_group_sub(group_number: str, platform_name: str, target: str): config = Config() try: - config.del_subscribe(int(group_number), 'group', target, platform_name) + config.del_subscribe(int(group_number), "group", target, platform_name) except (NoSuchUserException, NoSuchSubscribeException): - return { 'status': 400, 'msg': '删除错误' } - return { 'status': 200, 'msg': '' } + return {"status": 400, "msg": "删除错误"} + return {"status": 200, "msg": ""} -async def update_group_sub(group_number: str, platform_name: str, target: str, - target_name: str, cats: list[int], tags: list[str]): +async def update_group_sub( + group_number: str, + platform_name: str, + target: str, + target_name: str, + cats: list[int], + tags: list[str], +): config = Config() try: - config.update_subscribe(int(group_number), 'group', - target, target_name, platform_name, cats, tags) + config.update_subscribe( + int(group_number), "group", target, target_name, platform_name, cats, tags + ) except (NoSuchUserException, NoSuchSubscribeException): - return { 'status': 400, 'msg': '更新错误' } - return { 'status': 200, 'msg': '' } - + return {"status": 400, "msg": "更新错误"} + return {"status": 200, "msg": ""} diff --git a/src/plugins/nonebot_bison/admin_page/jwt.py b/src/plugins/nonebot_bison/admin_page/jwt.py index c607747..661621a 100644 --- a/src/plugins/nonebot_bison/admin_page/jwt.py +++ b/src/plugins/nonebot_bison/admin_page/jwt.py @@ -1,20 +1,23 @@ +import datetime import random import string from typing import Optional -import jwt -import datetime -_key = ''.join(random.SystemRandom().choice(string.ascii_letters) for _ in range(16)) +import jwt + +_key = "".join(random.SystemRandom().choice(string.ascii_letters) for _ in range(16)) + def pack_jwt(obj: dict) -> str: return jwt.encode( - {'exp': datetime.datetime.utcnow() + datetime.timedelta(hours=1), **obj}, - _key, algorithm='HS256' - ) + {"exp": datetime.datetime.utcnow() + datetime.timedelta(hours=1), **obj}, + _key, + algorithm="HS256", + ) + def load_jwt(token: str) -> Optional[dict]: try: - return jwt.decode(token, _key, algorithms=['HS256']) + return jwt.decode(token, _key, algorithms=["HS256"]) except: return None - diff --git a/src/plugins/nonebot_bison/admin_page/token_manager.py b/src/plugins/nonebot_bison/admin_page/token_manager.py index 47e3927..e540656 100644 --- a/src/plugins/nonebot_bison/admin_page/token_manager.py +++ b/src/plugins/nonebot_bison/admin_page/token_manager.py @@ -1,24 +1,26 @@ -from typing import Optional -from expiringdict import ExpiringDict import random -import string +import string +from typing import Optional + +from expiringdict import ExpiringDict + class TokenManager: - def __init__(self): - self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60*10) + self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60 * 10) def get_user(self, token: str) -> Optional[tuple]: res = self.token_manager.get(token) - assert(res is None or isinstance(res, tuple)) + assert res is None or isinstance(res, tuple) return res def save_user(self, token: str, qq: tuple) -> None: self.token_manager[token] = qq def get_user_token(self, qq: tuple) -> str: - token = ''.join(random.choices(string.ascii_letters + string.digits, k=16)) + token = "".join(random.choices(string.ascii_letters + string.digits, k=16)) self.save_user(token, qq) return token + token_manager = TokenManager() diff --git a/src/plugins/nonebot_bison/config.py b/src/plugins/nonebot_bison/config.py index 0712fcb..a021aeb 100644 --- a/src/plugins/nonebot_bison/config.py +++ b/src/plugins/nonebot_bison/config.py @@ -1,6 +1,6 @@ +import os from collections import defaultdict from os import path -import os from typing import DefaultDict, Literal, Mapping, TypedDict import nonebot @@ -14,26 +14,30 @@ from .utils import Singleton supported_target_type = platform_manager.keys() + def get_config_path() -> str: if plugin_config.bison_config_path: data_dir = plugin_config.bison_config_path else: working_dir = os.getcwd() - data_dir = path.join(working_dir, 'data') + data_dir = path.join(working_dir, "data") if not path.isdir(data_dir): os.makedirs(data_dir) - old_path = path.join(data_dir, 'hk_reporter.json') - new_path = path.join(data_dir, 'bison.json') + old_path = path.join(data_dir, "hk_reporter.json") + new_path = path.join(data_dir, "bison.json") if os.path.exists(old_path) and not os.path.exists(new_path): os.rename(old_path, new_path) return new_path + class NoSuchUserException(Exception): pass + class NoSuchSubscribeException(Exception): pass + class SubscribeContent(TypedDict): target: str target_type: str @@ -41,75 +45,105 @@ class SubscribeContent(TypedDict): cats: list[int] tags: list[str] + class ConfigContent(TypedDict): user: str user_type: Literal["group", "private"] subs: list[SubscribeContent] + class Config(metaclass=Singleton): migrate_version = 2 - + def __init__(self): - self.db = TinyDB(get_config_path(), encoding='utf-8') - self.kv_config = self.db.table('kv') - self.user_target = self.db.table('user_target') + self.db = TinyDB(get_config_path(), encoding="utf-8") + self.kv_config = self.db.table("kv") + self.user_target = self.db.table("user_target") self.target_user_cache: dict[str, defaultdict[Target, list[User]]] = {} self.target_user_cat_cache = {} self.target_user_tag_cache = {} self.target_list = {} self.next_index: DefaultDict[str, int] = defaultdict(lambda: 0) - - def add_subscribe(self, user, user_type, target, target_name, target_type, cats, tags): + + def add_subscribe( + self, user, user_type, target, target_name, target_type, cats, tags + ): user_query = Query() query = (user_query.user == user) & (user_query.user_type == user_type) - if (user_data := self.user_target.get(query)): + if user_data := self.user_target.get(query): # update - subs: list = user_data.get('subs', []) - subs.append({"target": target, "target_type": target_type, 'target_name': target_name, 'cats': cats, 'tags': tags}) + subs: list = user_data.get("subs", []) + subs.append( + { + "target": target, + "target_type": target_type, + "target_name": target_name, + "cats": cats, + "tags": tags, + } + ) self.user_target.update({"subs": subs}, query) else: # insert - self.user_target.insert({ - 'user': user, 'user_type': user_type, - 'subs': [{'target': target, 'target_type': target_type, 'target_name': target_name, 'cats': cats, 'tags': tags }] - }) + self.user_target.insert( + { + "user": user, + "user_type": user_type, + "subs": [ + { + "target": target, + "target_type": target_type, + "target_name": target_name, + "cats": cats, + "tags": tags, + } + ], + } + ) self.update_send_cache() def list_subscribe(self, user, user_type) -> list[SubscribeContent]: query = Query() - if user_sub := self.user_target.get((query.user == user) & (query.user_type ==user_type)): - return user_sub['subs'] + if user_sub := self.user_target.get( + (query.user == user) & (query.user_type == user_type) + ): + return user_sub["subs"] return [] def get_all_subscribe(self): return self.user_target - + def del_subscribe(self, user, user_type, target, target_type): user_query = Query() query = (user_query.user == user) & (user_query.user_type == user_type) if not (query_res := self.user_target.get(query)): raise NoSuchUserException() - subs = query_res.get('subs', []) + subs = query_res.get("subs", []) for idx, sub in enumerate(subs): - if sub.get('target') == target and sub.get('target_type') == target_type: + if sub.get("target") == target and sub.get("target_type") == target_type: subs.pop(idx) - self.user_target.update({'subs': subs}, query) + self.user_target.update({"subs": subs}, query) self.update_send_cache() return raise NoSuchSubscribeException() - def update_subscribe(self, user, user_type, target, target_name, target_type, cats, tags): + def update_subscribe( + self, user, user_type, target, target_name, target_type, cats, tags + ): user_query = Query() query = (user_query.user == user) & (user_query.user_type == user_type) - if (user_data := self.user_target.get(query)): + if user_data := self.user_target.get(query): # update - subs: list = user_data.get('subs', []) + subs: list = user_data.get("subs", []) find_flag = False for item in subs: - if item['target'] == target and item['target_type'] == target_type: - item['target_name'], item['cats'], item['tags'] = \ - target_name, cats, tags + if item["target"] == target and item["target_type"] == target_type: + item["target_name"], item["cats"], item["tags"] = ( + target_name, + cats, + tags, + ) find_flag = True break if not find_flag: @@ -121,33 +155,58 @@ class Config(metaclass=Singleton): def update_send_cache(self): res = {target_type: defaultdict(list) for target_type in supported_target_type} - cat_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} - tag_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} + cat_res = { + target_type: defaultdict(lambda: defaultdict(list)) + for target_type in supported_target_type + } + tag_res = { + target_type: defaultdict(lambda: defaultdict(list)) + for target_type in supported_target_type + } # res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} to_del = [] for user in self.user_target.all(): - for sub in user.get('subs', []): - if not sub.get('target_type') in supported_target_type: - to_del.append({'user': user['user'], 'user_type': user['user_type'], 'target': sub['target'], 'target_type': sub['target_type']}) + for sub in user.get("subs", []): + if not sub.get("target_type") in supported_target_type: + to_del.append( + { + "user": user["user"], + "user_type": user["user_type"], + "target": sub["target"], + "target_type": sub["target_type"], + } + ) continue - res[sub['target_type']][sub['target']].append(User(user['user'], user['user_type'])) - cat_res[sub['target_type']][sub['target']]['{}-{}'.format(user['user_type'], user['user'])] = sub['cats'] - tag_res[sub['target_type']][sub['target']]['{}-{}'.format(user['user_type'], user['user'])] = sub['tags'] + res[sub["target_type"]][sub["target"]].append( + User(user["user"], user["user_type"]) + ) + cat_res[sub["target_type"]][sub["target"]][ + "{}-{}".format(user["user_type"], user["user"]) + ] = sub["cats"] + tag_res[sub["target_type"]][sub["target"]][ + "{}-{}".format(user["user_type"], user["user"]) + ] = sub["tags"] self.target_user_cache = res self.target_user_cat_cache = cat_res self.target_user_tag_cache = tag_res for target_type in self.target_user_cache: - self.target_list[target_type] = list(self.target_user_cache[target_type].keys()) - - logger.info(f'Deleting {to_del}') + self.target_list[target_type] = list( + self.target_user_cache[target_type].keys() + ) + + logger.info(f"Deleting {to_del}") for d in to_del: self.del_subscribe(**d) def get_sub_category(self, target_type, target, user_type, user): - return self.target_user_cat_cache[target_type][target]['{}-{}'.format(user_type, user)] + return self.target_user_cat_cache[target_type][target][ + "{}-{}".format(user_type, user) + ] def get_sub_tags(self, target_type, target, user_type, user): - return self.target_user_tag_cache[target_type][target]['{}-{}'.format(user_type, user)] + return self.target_user_tag_cache[target_type][target][ + "{}-{}".format(user_type, user) + ] def get_next_target(self, target_type): # FIXME 插入或删除target后对队列的影响(但是并不是大问题 @@ -158,25 +217,27 @@ class Config(metaclass=Singleton): self.next_index[target_type] += 1 return res + def start_up(): config = Config() - if not (search_res := config.kv_config.search(Query().name=="version")): + if not (search_res := config.kv_config.search(Query().name == "version")): config.kv_config.insert({"name": "version", "value": config.migrate_version}) elif search_res[0].get("value") < config.migrate_version: query = Query() - version_query = (query.name == 'version') + version_query = query.name == "version" cur_version = search_res[0].get("value") if cur_version == 1: cur_version = 2 for user_conf in config.user_target.all(): conf_id = user_conf.doc_id - subs = user_conf['subs'] + subs = user_conf["subs"] for sub in subs: - sub['cats'] = [] - sub['tags'] = [] - config.user_target.update({'subs': subs}, doc_ids=[conf_id]) + sub["cats"] = [] + sub["tags"] = [] + config.user_target.update({"subs": subs}, doc_ids=[conf_id]) config.kv_config.update({"value": config.migrate_version}, version_query) # do migration config.update_send_cache() + nonebot.get_driver().on_startup(start_up) diff --git a/src/plugins/nonebot_bison/config_manager.py b/src/plugins/nonebot_bison/config_manager.py index 436c0c6..8e96ea6 100644 --- a/src/plugins/nonebot_bison/config_manager.py +++ b/src/plugins/nonebot_bison/config_manager.py @@ -7,7 +7,7 @@ from nonebot.adapters.cqhttp import Bot, Event, GroupMessageEvent from nonebot.adapters.cqhttp.message import Message from nonebot.adapters.cqhttp.permission import GROUP_ADMIN, GROUP_MEMBER, GROUP_OWNER from nonebot.matcher import Matcher -from nonebot.permission import Permission, SUPERUSER +from nonebot.permission import SUPERUSER, Permission from nonebot.rule import to_me from nonebot.typing import T_State @@ -16,137 +16,184 @@ from .platform import check_sub_target, platform_manager from .types import Target from .utils import parse_text + def _gen_prompt_template(prompt: str): - if hasattr(Message, 'template'): + if hasattr(Message, "template"): return Message.template(prompt) return prompt -common_platform = [p.platform_name for p in \ - filter(lambda platform: platform.enabled and platform.is_common, - platform_manager.values()) - ] -help_match = on_command('help', rule=to_me(), priority=5) +common_platform = [ + p.platform_name + for p in filter( + lambda platform: platform.enabled and platform.is_common, + platform_manager.values(), + ) +] + +help_match = on_command("help", rule=to_me(), priority=5) + + @help_match.handle() async def send_help(bot: Bot, event: Event, state: T_State): - message = '使用方法:\n@bot 添加订阅(仅管理员)\n@bot 查询订阅\n@bot 删除订阅(仅管理员)' + message = "使用方法:\n@bot 添加订阅(仅管理员)\n@bot 查询订阅\n@bot 删除订阅(仅管理员)" await help_match.finish(Message(await parse_text(message))) def do_add_sub(add_sub: Type[Matcher]): @add_sub.handle() async def init_promote(bot: Bot, event: Event, state: T_State): - state['_prompt'] = '请输入想要订阅的平台,目前支持,请输入冒号左边的名称:\n' + \ - ''.join(['{}:{}\n'.format(platform_name, platform_manager[platform_name].name) \ - for platform_name in common_platform]) + \ - '要查看全部平台请输入:“全部”' + state["_prompt"] = ( + "请输入想要订阅的平台,目前支持,请输入冒号左边的名称:\n" + + "".join( + [ + "{}:{}\n".format( + platform_name, platform_manager[platform_name].name + ) + for platform_name in common_platform + ] + ) + + "要查看全部平台请输入:“全部”" + ) - async def parse_platform(bot: AbstractBot, event: AbstractEvent, state: T_State) -> None: + async def parse_platform( + bot: AbstractBot, event: AbstractEvent, state: T_State + ) -> None: platform = str(event.get_message()).strip() - if platform == '全部': - message = '全部平台\n' + \ - '\n'.join(['{}:{}'.format(platform_name, platform.name) \ - for platform_name, platform in platform_manager.items()]) + if platform == "全部": + message = "全部平台\n" + "\n".join( + [ + "{}:{}".format(platform_name, platform.name) + for platform_name, platform in platform_manager.items() + ] + ) await add_sub.reject(message) elif platform in platform_manager: - state['platform'] = platform + state["platform"] = platform else: - await add_sub.reject('平台输入错误') + await add_sub.reject("平台输入错误") - @add_sub.got('platform', _gen_prompt_template('{_prompt}'), parse_platform) + @add_sub.got("platform", _gen_prompt_template("{_prompt}"), parse_platform) @add_sub.handle() async def init_id(bot: Bot, event: Event, state: T_State): - if platform_manager[state['platform']].has_target: - state['_prompt'] = '请输入订阅用户的id,详情查阅https://nonebot-bison.vercel.app/usage/#%E6%89%80%E6%94%AF%E6%8C%81%E5%B9%B3%E5%8F%B0%E7%9A%84uid' + if platform_manager[state["platform"]].has_target: + state[ + "_prompt" + ] = "请输入订阅用户的id,详情查阅https://nonebot-bison.vercel.app/usage/#%E6%89%80%E6%94%AF%E6%8C%81%E5%B9%B3%E5%8F%B0%E7%9A%84uid" else: - state['id'] = 'default' - state['name'] = await platform_manager[state['platform']].get_target_name(Target('')) + state["id"] = "default" + state["name"] = await platform_manager[state["platform"]].get_target_name( + Target("") + ) async def parse_id(bot: AbstractBot, event: AbstractEvent, state: T_State): target = str(event.get_message()).strip() try: - name = await check_sub_target(state['platform'], target) + name = await check_sub_target(state["platform"], target) if not name: - await add_sub.reject('id输入错误') - state['id'] = target - state['name'] = name + await add_sub.reject("id输入错误") + state["id"] = target + state["name"] = name except: - await add_sub.reject('id输入错误') + await add_sub.reject("id输入错误") - @add_sub.got('id', _gen_prompt_template('{_prompt}'), parse_id) + @add_sub.got("id", _gen_prompt_template("{_prompt}"), parse_id) @add_sub.handle() async def init_cat(bot: Bot, event: Event, state: T_State): - if not platform_manager[state['platform']].categories: - state['cats'] = [] + if not platform_manager[state["platform"]].categories: + state["cats"] = [] return - state['_prompt'] = '请输入要订阅的类别,以空格分隔,支持的类别有:{}'.format( - ' '.join(list(platform_manager[state['platform']].categories.values()))) + state["_prompt"] = "请输入要订阅的类别,以空格分隔,支持的类别有:{}".format( + " ".join(list(platform_manager[state["platform"]].categories.values())) + ) async def parser_cats(bot: AbstractBot, event: AbstractEvent, state: T_State): res = [] for cat in str(event.get_message()).strip().split(): - if cat not in platform_manager[state['platform']].reverse_category: - await add_sub.reject('不支持 {}'.format(cat)) - res.append(platform_manager[state['platform']].reverse_category[cat]) - state['cats'] = res + if cat not in platform_manager[state["platform"]].reverse_category: + await add_sub.reject("不支持 {}".format(cat)) + res.append(platform_manager[state["platform"]].reverse_category[cat]) + state["cats"] = res - @add_sub.got('cats', _gen_prompt_template('{_prompt}'), parser_cats) + @add_sub.got("cats", _gen_prompt_template("{_prompt}"), parser_cats) @add_sub.handle() async def init_tag(bot: Bot, event: Event, state: T_State): - if not platform_manager[state['platform']].enable_tag: - state['tags'] = [] + if not platform_manager[state["platform"]].enable_tag: + state["tags"] = [] return - state['_prompt'] = '请输入要订阅的tag,订阅所有tag输入"全部标签"' + state["_prompt"] = '请输入要订阅的tag,订阅所有tag输入"全部标签"' async def parser_tags(bot: AbstractBot, event: AbstractEvent, state: T_State): - if str(event.get_message()).strip() == '全部标签': - state['tags'] = [] + if str(event.get_message()).strip() == "全部标签": + state["tags"] = [] else: - state['tags'] = str(event.get_message()).strip().split() + state["tags"] = str(event.get_message()).strip().split() - @add_sub.got('tags', _gen_prompt_template('{_prompt}'), parser_tags) + @add_sub.got("tags", _gen_prompt_template("{_prompt}"), parser_tags) @add_sub.handle() async def add_sub_process(bot: Bot, event: Event, state: T_State): config = Config() - config.add_subscribe(state.get('_user_id') or event.group_id, user_type='group', - target=state['id'], - target_name=state['name'], target_type=state['platform'], - cats=state.get('cats', []), tags=state.get('tags', [])) - await add_sub.finish('添加 {} 成功'.format(state['name'])) + config.add_subscribe( + state.get("_user_id") or event.group_id, + user_type="group", + target=state["id"], + target_name=state["name"], + target_type=state["platform"], + cats=state.get("cats", []), + tags=state.get("tags", []), + ) + await add_sub.finish("添加 {} 成功".format(state["name"])) + def do_query_sub(query_sub: Type[Matcher]): @query_sub.handle() async def _(bot: Bot, event: Event, state: T_State): config: Config = Config() - sub_list = config.list_subscribe(state.get('_user_id') or event.group_id, "group") - res = '订阅的帐号为:\n' + sub_list = config.list_subscribe( + state.get("_user_id") or event.group_id, "group" + ) + res = "订阅的帐号为:\n" for sub in sub_list: - res += '{} {} {}'.format(sub['target_type'], sub['target_name'], sub['target']) - platform = platform_manager[sub['target_type']] + res += "{} {} {}".format( + sub["target_type"], sub["target_name"], sub["target"] + ) + platform = platform_manager[sub["target_type"]] if platform.categories: - res += ' [{}]'.format(', '.join(map(lambda x: platform.categories[x], sub['cats']))) + res += " [{}]".format( + ", ".join(map(lambda x: platform.categories[x], sub["cats"])) + ) if platform.enable_tag: - res += ' {}'.format(', '.join(sub['tags'])) - res += '\n' + res += " {}".format(", ".join(sub["tags"])) + res += "\n" await query_sub.finish(Message(await parse_text(res))) + def do_del_sub(del_sub: Type[Matcher]): @del_sub.handle() async def send_list(bot: Bot, event: Event, state: T_State): config: Config = Config() - sub_list = config.list_subscribe(state.get('_user_id') or event.group_id, "group") - res = '订阅的帐号为:\n' - state['sub_table'] = {} + sub_list = config.list_subscribe( + state.get("_user_id") or event.group_id, "group" + ) + res = "订阅的帐号为:\n" + state["sub_table"] = {} for index, sub in enumerate(sub_list, 1): - state['sub_table'][index] = {'target_type': sub['target_type'], 'target': sub['target']} - res += '{} {} {} {}\n'.format(index, sub['target_type'], sub['target_name'], sub['target']) - platform = platform_manager[sub['target_type']] + state["sub_table"][index] = { + "target_type": sub["target_type"], + "target": sub["target"], + } + res += "{} {} {} {}\n".format( + index, sub["target_type"], sub["target_name"], sub["target"] + ) + platform = platform_manager[sub["target_type"]] if platform.categories: - res += ' [{}]'.format(', '.join(map(lambda x: platform.categories[x], sub['cats']))) + res += " [{}]".format( + ", ".join(map(lambda x: platform.categories[x], sub["cats"])) + ) if platform.enable_tag: - res += ' {}'.format(', '.join(sub['tags'])) - res += '\n' - res += '请输入要删除的订阅的序号' + res += " {}".format(", ".join(sub["tags"])) + res += "\n" + res += "请输入要删除的订阅的序号" await bot.send(event=event, message=Message(await parse_text(res))) @del_sub.receive() @@ -154,39 +201,60 @@ def do_del_sub(del_sub: Type[Matcher]): try: index = int(str(event.get_message()).strip()) config = Config() - config.del_subscribe(state.get('_user_id') or event.group_id, 'group', **state['sub_table'][index]) + config.del_subscribe( + state.get("_user_id") or event.group_id, + "group", + **state["sub_table"][index] + ) except Exception as e: - await del_sub.reject('删除错误') + await del_sub.reject("删除错误") logger.warning(e) else: - await del_sub.finish('删除成功') + await del_sub.finish("删除成功") + async def parse_group_number(bot: AbstractBot, event: AbstractEvent, state: T_State): state[state["_current_key"]] = int(str(event.get_message())) -add_sub_matcher = on_command("添加订阅", rule=to_me(), permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, priority=5) +add_sub_matcher = on_command( + "添加订阅", rule=to_me(), permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, priority=5 +) do_add_sub(add_sub_matcher) -manage_add_sub_mather = on_command('管理-添加订阅', permission=SUPERUSER, priority=5) -@manage_add_sub_mather.got('_user_id', "群号", parse_group_number) +manage_add_sub_mather = on_command("管理-添加订阅", permission=SUPERUSER, priority=5) + + +@manage_add_sub_mather.got("_user_id", "群号", parse_group_number) async def handle(bot: Bot, event: Event, state: T_State): pass + + do_add_sub(manage_add_sub_mather) query_sub_macher = on_command("查询订阅", rule=to_me(), priority=5) do_query_sub(query_sub_macher) -manage_query_sub_mather = on_command('管理-查询订阅', permission=SUPERUSER, priority=5) -@manage_query_sub_mather.got('_user_id', "群号", parse_group_number) +manage_query_sub_mather = on_command("管理-查询订阅", permission=SUPERUSER, priority=5) + + +@manage_query_sub_mather.got("_user_id", "群号", parse_group_number) async def handle(bot: Bot, event: Event, state: T_State): pass + + do_query_sub(manage_query_sub_mather) -del_sub_macher = on_command("删除订阅", rule=to_me(), permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, priority=5) +del_sub_macher = on_command( + "删除订阅", rule=to_me(), permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, priority=5 +) do_del_sub(del_sub_macher) -manage_del_sub_mather = on_command('管理-删除订阅', permission=SUPERUSER, priority=5) -@manage_del_sub_mather.got('_user_id', "群号", parse_group_number) +manage_del_sub_mather = on_command("管理-删除订阅", permission=SUPERUSER, priority=5) + + +@manage_del_sub_mather.got("_user_id", "群号", parse_group_number) async def handle(bot: Bot, event: Event, state: T_State): pass + + do_del_sub(manage_del_sub_mather) diff --git a/src/plugins/nonebot_bison/platform/__init__.py b/src/plugins/nonebot_bison/platform/__init__.py index 028f002..60b5e32 100644 --- a/src/plugins/nonebot_bison/platform/__init__.py +++ b/src/plugins/nonebot_bison/platform/__init__.py @@ -1,18 +1,19 @@ from collections import defaultdict - -from .platform import Platform, NoTargetGroup -from pkgutil import iter_modules -from pathlib import Path from importlib import import_module +from pathlib import Path +from pkgutil import iter_modules + +from .platform import NoTargetGroup, Platform _package_dir = str(Path(__file__).resolve().parent) for (_, module_name, _) in iter_modules([_package_dir]): - import_module(f'{__name__}.{module_name}') + import_module(f"{__name__}.{module_name}") async def check_sub_target(target_type, target): return await platform_manager[target_type].get_target_name(target) + _platform_list = defaultdict(list) for _platform in Platform.registry: if not _platform.enabled: @@ -24,5 +25,6 @@ for name, platform_list in _platform_list.items(): if len(platform_list) == 1: platform_manager[name] = platform_list[0]() else: - platform_manager[name] = NoTargetGroup([_platform() for _platform in platform_list]) - + platform_manager[name] = NoTargetGroup( + [_platform() for _platform in platform_list] + ) diff --git a/src/plugins/nonebot_bison/platform/arknights.py b/src/plugins/nonebot_bison/platform/arknights.py index fb25729..f7b0002 100644 --- a/src/plugins/nonebot_bison/platform/arknights.py +++ b/src/plugins/nonebot_bison/platform/arknights.py @@ -1,8 +1,8 @@ import json from typing import Any -from bs4 import BeautifulSoup as bs import httpx +from bs4 import BeautifulSoup as bs from ..post import Post from ..types import Category, RawPost, Target @@ -12,26 +12,28 @@ from .platform import CategoryNotSupport, NewMessage, StatusChange class Arknights(NewMessage): - categories = {1: '游戏公告'} - platform_name = 'arknights' - name = '明日方舟游戏信息' + categories = {1: "游戏公告"} + platform_name = "arknights" + name = "明日方舟游戏信息" enable_tag = False enabled = True is_common = False - schedule_type = 'interval' - schedule_kw = {'seconds': 30} + schedule_type = "interval" + schedule_kw = {"seconds": 30} has_target = False async def get_target_name(self, _: Target) -> str: - return '明日方舟游戏信息' + return "明日方舟游戏信息" async def get_sub_list(self, _) -> list[RawPost]: async with httpx.AsyncClient() as client: - raw_data = await client.get('https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/announcement.meta.json') - return json.loads(raw_data.text)['announceList'] + raw_data = await client.get( + "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/announcement.meta.json" + ) + return json.loads(raw_data.text)["announceList"] def get_id(self, post: RawPost) -> Any: - return post['announceId'] + return post["announceId"] def get_date(self, _: RawPost) -> None: return None @@ -40,64 +42,85 @@ class Arknights(NewMessage): return Category(1) async def parse(self, raw_post: RawPost) -> Post: - announce_url = raw_post['webUrl'] - text = '' + announce_url = raw_post["webUrl"] + text = "" async with httpx.AsyncClient() as client: raw_html = await client.get(announce_url) - soup = bs(raw_html, 'html.parser') + soup = bs(raw_html, "html.parser") pics = [] if soup.find("div", class_="standerd-container"): # 图文 render = Render() - viewport = {'width': 320, 'height': 6400, 'deviceScaleFactor': 3} - pic_data = await render.render(announce_url, viewport=viewport, target='div.main') + viewport = {"width": 320, "height": 6400, "deviceScaleFactor": 3} + pic_data = await render.render( + announce_url, viewport=viewport, target="div.main" + ) if pic_data: pics.append(pic_data) else: - text = '图片渲染失败' - elif (pic := soup.find('img', class_='banner-image')): - pics.append(pic['src']) + text = "图片渲染失败" + elif pic := soup.find("img", class_="banner-image"): + pics.append(pic["src"]) else: raise CategoryNotSupport() - return Post('arknights', text=text, url='', target_name="明日方舟游戏内公告", pics=pics, compress=True, override_use_pic=False) + return Post( + "arknights", + text=text, + url="", + target_name="明日方舟游戏内公告", + pics=pics, + compress=True, + override_use_pic=False, + ) + class AkVersion(StatusChange): - categories = {2: '更新信息'} - platform_name = 'arknights' - name = '明日方舟游戏信息' + categories = {2: "更新信息"} + platform_name = "arknights" + name = "明日方舟游戏信息" enable_tag = False enabled = True is_common = False - schedule_type = 'interval' - schedule_kw = {'seconds': 30} + schedule_type = "interval" + schedule_kw = {"seconds": 30} has_target = False async def get_target_name(self, _: Target) -> str: - return '明日方舟游戏信息' + return "明日方舟游戏信息" async def get_status(self, _): async with httpx.AsyncClient() as client: - res_ver = await client.get('https://ak-conf.hypergryph.com/config/prod/official/IOS/version') - res_preanounce = await client.get('https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json') + res_ver = await client.get( + "https://ak-conf.hypergryph.com/config/prod/official/IOS/version" + ) + res_preanounce = await client.get( + "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json" + ) res = res_ver.json() res.update(res_preanounce.json()) return res def compare_status(self, _, old_status, new_status): res = [] - if old_status.get('preAnnounceType') == 2 and new_status.get('preAnnounceType') == 0: - res.append(Post('arknights', - text='登录界面维护公告上线(大概是开始维护了)', - target_name='明日方舟更新信息')) - elif old_status.get('preAnnounceType') == 0 and new_status.get('preAnnounceType') == 2: - res.append(Post('arknights', - text='登录界面维护公告下线(大概是开服了,冲!)', - target_name='明日方舟更新信息')) - if old_status.get('clientVersion') != new_status.get('clientVersion'): - res.append(Post('arknights', text='游戏本体更新(大更新)', target_name='明日方舟更新信息')) - if old_status.get('resVersion') != new_status.get('resVersion'): - res.append(Post('arknights', text='游戏资源更新(小更新)', target_name='明日方舟更新信息')) + if ( + old_status.get("preAnnounceType") == 2 + and new_status.get("preAnnounceType") == 0 + ): + res.append( + Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息") + ) + elif ( + old_status.get("preAnnounceType") == 0 + and new_status.get("preAnnounceType") == 2 + ): + res.append( + Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息") + ) + if old_status.get("clientVersion") != new_status.get("clientVersion"): + res.append(Post("arknights", text="游戏本体更新(大更新)", target_name="明日方舟更新信息")) + if old_status.get("resVersion") != new_status.get("resVersion"): + res.append(Post("arknights", text="游戏资源更新(小更新)", target_name="明日方舟更新信息")) return res def get_category(self, _): @@ -106,28 +129,29 @@ class AkVersion(StatusChange): async def parse(self, raw_post): return raw_post + class MonsterSiren(NewMessage): - categories = {3: '塞壬唱片新闻'} - platform_name = 'arknights' - name = '明日方舟游戏信息' + categories = {3: "塞壬唱片新闻"} + platform_name = "arknights" + name = "明日方舟游戏信息" enable_tag = False enabled = True is_common = False - schedule_type = 'interval' - schedule_kw = {'seconds': 30} + schedule_type = "interval" + schedule_kw = {"seconds": 30} has_target = False async def get_target_name(self, _: Target) -> str: - return '明日方舟游戏信息' + return "明日方舟游戏信息" async def get_sub_list(self, _) -> list[RawPost]: async with httpx.AsyncClient() as client: - raw_data = await client.get('https://monster-siren.hypergryph.com/api/news') - return raw_data.json()['data']['list'] + raw_data = await client.get("https://monster-siren.hypergryph.com/api/news") + return raw_data.json()["data"]["list"] def get_id(self, post: RawPost) -> Any: - return post['cid'] + return post["cid"] def get_date(self, _) -> None: return None @@ -138,13 +162,21 @@ class MonsterSiren(NewMessage): async def parse(self, raw_post: RawPost) -> Post: url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}' async with httpx.AsyncClient() as client: - res = await client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}') + res = await client.get( + f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}' + ) raw_data = res.json() - content = raw_data['data']['content'] - content = content.replace('

', '

\n') - soup = bs(content, 'html.parser') - imgs = list(map(lambda x: x['src'], soup('img'))) + content = raw_data["data"]["content"] + content = content.replace("

", "

\n") + soup = bs(content, "html.parser") + imgs = list(map(lambda x: x["src"], soup("img"))) text = f'{raw_post["title"]}\n{soup.text.strip()}' - return Post('monster-siren', text=text, pics=imgs, - url=url, target_name="塞壬唱片新闻", compress=True, - override_use_pic=False) + return Post( + "monster-siren", + text=text, + pics=imgs, + url=url, + target_name="塞壬唱片新闻", + compress=True, + override_use_pic=False, + ) diff --git a/src/plugins/nonebot_bison/platform/bilibili.py b/src/plugins/nonebot_bison/platform/bilibili.py index 8980eb1..5ff3cc1 100644 --- a/src/plugins/nonebot_bison/platform/bilibili.py +++ b/src/plugins/nonebot_bison/platform/bilibili.py @@ -5,50 +5,57 @@ import httpx from ..post import Post from ..types import Category, RawPost, Tag, Target -from .platform import NewMessage, CategoryNotSupport +from .platform import CategoryNotSupport, NewMessage + class Bilibili(NewMessage): categories = { - 1: "一般动态", - 2: "专栏文章", - 3: "视频", - 4: "纯文字", - 5: "转发" - # 5: "短视频" - } - platform_name = 'bilibili' + 1: "一般动态", + 2: "专栏文章", + 3: "视频", + 4: "纯文字", + 5: "转发" + # 5: "短视频" + } + platform_name = "bilibili" enable_tag = True enabled = True is_common = True - schedule_type = 'interval' - schedule_kw = {'seconds': 10} - name = 'B站' + schedule_type = "interval" + schedule_kw = {"seconds": 10} + name = "B站" has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: - res = await client.get('https://api.bilibili.com/x/space/acc/info', params={'mid': target}) + res = await client.get( + "https://api.bilibili.com/x/space/acc/info", params={"mid": target} + ) res_data = json.loads(res.text) - if res_data['code']: + if res_data["code"]: return None - return res_data['data']['name'] + return res_data["data"]["name"] async def get_sub_list(self, target: Target) -> list[RawPost]: async with httpx.AsyncClient() as client: - params = {'host_uid': target, 'offset': 0, 'need_top': 0} - res = await client.get('https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history', params=params, timeout=4.0) + params = {"host_uid": target, "offset": 0, "need_top": 0} + res = await client.get( + "https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history", + params=params, + timeout=4.0, + ) res_dict = json.loads(res.text) - if res_dict['code'] == 0: - return res_dict['data']['cards'] + if res_dict["code"] == 0: + return res_dict["data"]["cards"] else: return [] def get_id(self, post: RawPost) -> Any: - return post['desc']['dynamic_id'] - + return post["desc"]["dynamic_id"] + def get_date(self, post: RawPost) -> int: - return post['desc']['timestamp'] + return post["desc"]["timestamp"] def _do_get_category(self, post_type: int) -> Category: if post_type == 2: @@ -65,63 +72,75 @@ class Bilibili(NewMessage): raise CategoryNotSupport() def get_category(self, post: RawPost) -> Category: - post_type = post['desc']['type'] + post_type = post["desc"]["type"] return self._do_get_category(post_type) def get_tags(self, raw_post: RawPost) -> list[Tag]: - return [*map(lambda tp: tp['topic_name'], raw_post['display']['topic_info']['topic_details'])] + return [ + *map( + lambda tp: tp["topic_name"], + raw_post["display"]["topic_info"]["topic_details"], + ) + ] def _get_info(self, post_type: Category, card) -> tuple[str, list]: if post_type == 1: # 一般动态 - text = card['item']['description'] - pic = [img['img_src'] for img in card['item']['pictures']] + text = card["item"]["description"] + pic = [img["img_src"] for img in card["item"]["pictures"]] elif post_type == 2: # 专栏文章 - text = '{} {}'.format(card['title'], card['summary']) - pic = card['image_urls'] + text = "{} {}".format(card["title"], card["summary"]) + pic = card["image_urls"] elif post_type == 3: # 视频 - text = card['dynamic'] - pic = [card['pic']] + text = card["dynamic"] + pic = [card["pic"]] elif post_type == 4: # 纯文字 - text = card['item']['content'] + text = card["item"]["content"] pic = [] else: raise CategoryNotSupport() return text, pic async def parse(self, raw_post: RawPost) -> Post: - card_content = json.loads(raw_post['card']) + card_content = json.loads(raw_post["card"]) post_type = self.get_category(raw_post) - target_name = raw_post['desc']['user_profile']['info']['uname'] + target_name = raw_post["desc"]["user_profile"]["info"]["uname"] if post_type >= 1 and post_type < 5: - url = '' + url = "" if post_type == 1: # 一般动态 - url = 'https://t.bilibili.com/{}'.format(raw_post['desc']['dynamic_id_str']) + url = "https://t.bilibili.com/{}".format( + raw_post["desc"]["dynamic_id_str"] + ) elif post_type == 2: # 专栏文章 - url = 'https://www.bilibili.com/read/cv{}'.format(raw_post['desc']['rid']) + url = "https://www.bilibili.com/read/cv{}".format( + raw_post["desc"]["rid"] + ) elif post_type == 3: # 视频 - url = 'https://www.bilibili.com/video/{}'.format(raw_post['desc']['bvid']) + url = "https://www.bilibili.com/video/{}".format( + raw_post["desc"]["bvid"] + ) elif post_type == 4: # 纯文字 - url = 'https://t.bilibili.com/{}'.format(raw_post['desc']['dynamic_id_str']) + url = "https://t.bilibili.com/{}".format( + raw_post["desc"]["dynamic_id_str"] + ) text, pic = self._get_info(post_type, card_content) elif post_type == 5: # 转发 - url = 'https://t.bilibili.com/{}'.format(raw_post['desc']['dynamic_id_str']) - text = card_content['item']['content'] - orig_type = card_content['item']['orig_type'] - orig = json.loads(card_content['origin']) + url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"]) + text = card_content["item"]["content"] + orig_type = card_content["item"]["orig_type"] + orig = json.loads(card_content["origin"]) orig_text, _ = self._get_info(self._do_get_category(orig_type), orig) - text += '\n--------------\n' + text += "\n--------------\n" text += orig_text pic = [] else: raise CategoryNotSupport(post_type) - return Post('bilibili', text=text, url=url, pics=pic, target_name=target_name) - + return Post("bilibili", text=text, url=url, pics=pic, target_name=target_name) diff --git a/src/plugins/nonebot_bison/platform/ncm_artist.py b/src/plugins/nonebot_bison/platform/ncm_artist.py index e230f65..a30072f 100644 --- a/src/plugins/nonebot_bison/platform/ncm_artist.py +++ b/src/plugins/nonebot_bison/platform/ncm_artist.py @@ -1,54 +1,58 @@ from typing import Any, Optional import httpx + from ..post import Post from ..types import RawPost, Target from .platform import NewMessage + class NcmArtist(NewMessage): categories = {} - platform_name = 'ncm-artist' + platform_name = "ncm-artist" enable_tag = False enabled = True is_common = True - schedule_type = 'interval' - schedule_kw = {'minutes': 1} + schedule_type = "interval" + schedule_kw = {"minutes": 1} name = "网易云-歌手" has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: res = await client.get( - "https://music.163.com/api/artist/albums/{}".format(target), - headers={'Referer': 'https://music.163.com/'} - ) + "https://music.163.com/api/artist/albums/{}".format(target), + headers={"Referer": "https://music.163.com/"}, + ) res_data = res.json() - if res_data['code'] != 200: + if res_data["code"] != 200: return - return res_data['artist']['name'] + return res_data["artist"]["name"] async def get_sub_list(self, target: Target) -> list[RawPost]: async with httpx.AsyncClient() as client: res = await client.get( - "https://music.163.com/api/artist/albums/{}".format(target), - headers={'Referer': 'https://music.163.com/'} - ) + "https://music.163.com/api/artist/albums/{}".format(target), + headers={"Referer": "https://music.163.com/"}, + ) res_data = res.json() - if res_data['code'] != 200: + if res_data["code"] != 200: return [] else: - return res_data['hotAlbums'] + return res_data["hotAlbums"] def get_id(self, post: RawPost) -> Any: - return post['id'] + return post["id"] def get_date(self, post: RawPost) -> int: - return post['publishTime'] // 1000 + return post["publishTime"] // 1000 async def parse(self, raw_post: RawPost) -> Post: - text = '新专辑发布:{}'.format(raw_post['name']) - target_name = raw_post['artist']['name'] - pics = [raw_post['picUrl']] - url = "https://music.163.com/#/album?id={}".format(raw_post['id']) - return Post('ncm-artist', text=text, url=url, pics=pics, target_name=target_name) + text = "新专辑发布:{}".format(raw_post["name"]) + target_name = raw_post["artist"]["name"] + pics = [raw_post["picUrl"]] + url = "https://music.163.com/#/album?id={}".format(raw_post["id"]) + return Post( + "ncm-artist", text=text, url=url, pics=pics, target_name=target_name + ) diff --git a/src/plugins/nonebot_bison/platform/ncm_radio.py b/src/plugins/nonebot_bison/platform/ncm_radio.py index 6fae725..20abb52 100644 --- a/src/plugins/nonebot_bison/platform/ncm_radio.py +++ b/src/plugins/nonebot_bison/platform/ncm_radio.py @@ -1,56 +1,58 @@ from typing import Any, Optional import httpx + from ..post import Post from ..types import RawPost, Target from .platform import NewMessage + class NcmRadio(NewMessage): categories = {} - platform_name = 'ncm-radio' + platform_name = "ncm-radio" enable_tag = False enabled = True is_common = False - schedule_type = 'interval' - schedule_kw = {'minutes': 10} + schedule_type = "interval" + schedule_kw = {"minutes": 10} name = "网易云-电台" has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: res = await client.post( - "http://music.163.com/api/dj/program/byradio", - headers={'Referer': 'https://music.163.com/'}, - data={"radioId": target, "limit": 1000, "offset": 0} - ) + "http://music.163.com/api/dj/program/byradio", + headers={"Referer": "https://music.163.com/"}, + data={"radioId": target, "limit": 1000, "offset": 0}, + ) res_data = res.json() - if res_data['code'] != 200 or res_data['programs'] == 0: + if res_data["code"] != 200 or res_data["programs"] == 0: return - return res_data['programs'][0]['radio']['name'] + return res_data["programs"][0]["radio"]["name"] async def get_sub_list(self, target: Target) -> list[RawPost]: async with httpx.AsyncClient() as client: res = await client.post( - "http://music.163.com/api/dj/program/byradio", - headers={'Referer': 'https://music.163.com/'}, - data={"radioId": target, "limit": 1000, "offset": 0} - ) + "http://music.163.com/api/dj/program/byradio", + headers={"Referer": "https://music.163.com/"}, + data={"radioId": target, "limit": 1000, "offset": 0}, + ) res_data = res.json() - if res_data['code'] != 200: + if res_data["code"] != 200: return [] else: - return res_data['programs'] + return res_data["programs"] def get_id(self, post: RawPost) -> Any: - return post['id'] + return post["id"] def get_date(self, post: RawPost) -> int: - return post['createTime'] // 1000 + return post["createTime"] // 1000 async def parse(self, raw_post: RawPost) -> Post: - text = '网易云电台更新:{}'.format(raw_post['name']) - target_name = raw_post['radio']['name'] - pics = [raw_post['coverUrl']] - url = "https://music.163.com/#/program/{}".format(raw_post['id']) - return Post('ncm-radio', text=text, url=url, pics=pics, target_name=target_name) + text = "网易云电台更新:{}".format(raw_post["name"]) + target_name = raw_post["radio"]["name"] + pics = [raw_post["coverUrl"]] + url = "https://music.163.com/#/program/{}".format(raw_post["id"]) + return Post("ncm-radio", text=text, url=url, pics=pics, target_name=target_name) diff --git a/src/plugins/nonebot_bison/platform/platform.py b/src/plugins/nonebot_bison/platform/platform.py index cff3ff6..1bec3d8 100644 --- a/src/plugins/nonebot_bison/platform/platform.py +++ b/src/plugins/nonebot_bison/platform/platform.py @@ -1,8 +1,8 @@ -from abc import abstractmethod, ABC +import time +from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass -import time -from typing import Any, Collection, Optional, Literal +from typing import Any, Collection, Literal, Optional import httpx from nonebot import logger @@ -17,26 +17,27 @@ class CategoryNotSupport(Exception): class RegistryMeta(type): - def __new__(cls, name, bases, namespace, **kwargs): return super().__new__(cls, name, bases, namespace) def __init__(cls, name, bases, namespace, **kwargs): - if kwargs.get('base'): + if kwargs.get("base"): # this is the base class cls.registry = [] - elif not kwargs.get('abstract'): + elif not kwargs.get("abstract"): # this is the subclass cls.registry.append(cls) super().__init__(name, bases, namespace, **kwargs) + class RegistryABCMeta(RegistryMeta, ABC): ... + class Platform(metaclass=RegistryABCMeta, base=True): - - schedule_type: Literal['date', 'interval', 'cron'] + + schedule_type: Literal["date", "interval", "cron"] schedule_kw: dict is_common: bool enabled: bool @@ -52,7 +53,9 @@ class Platform(metaclass=RegistryABCMeta, base=True): ... @abstractmethod - async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: + async def fetch_new_post( + self, target: Target, users: list[UserSubInfo] + ) -> list[tuple[User, list[Post]]]: ... @abstractmethod @@ -67,7 +70,7 @@ class Platform(metaclass=RegistryABCMeta, base=True): super().__init__() self.reverse_category = {} for key, val in self.categories.items(): - self.reverse_category[val] = key + self.reverse_category[val] = key self.store = dict() @abstractmethod @@ -75,12 +78,14 @@ class Platform(metaclass=RegistryABCMeta, base=True): "Return Tag list of given RawPost" def get_stored_data(self, target: Target) -> Any: - return self.store.get(target) + return self.store.get(target) def set_stored_data(self, target: Target, data: Any): self.store[target] = data - async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]: + async def filter_user_custom( + self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag] + ) -> list[RawPost]: res: list[RawPost] = [] for raw_post in raw_post_list: if self.categories: @@ -99,12 +104,16 @@ class Platform(metaclass=RegistryABCMeta, base=True): res.append(raw_post) return res - async def dispatch_user_post(self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: + async def dispatch_user_post( + self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo] + ) -> list[tuple[User, list[Post]]]: res: list[tuple[User, list[Post]]] = [] for user, category_getter, tag_getter in users: required_tags = tag_getter(target) if self.enable_tag else [] cats = category_getter(target) - user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags) + user_raw_post = await self.filter_user_custom( + new_posts, cats, required_tags + ) user_post: list[Post] = [] for raw_post in user_raw_post: user_post.append(await self.do_parse(raw_post)) @@ -116,6 +125,7 @@ class Platform(metaclass=RegistryABCMeta, base=True): "Return category of given Rawpost" raise NotImplementedError() + class MessageProcess(Platform, abstract=True): "General message process fetch, parse, filter progress" @@ -127,7 +137,6 @@ class MessageProcess(Platform, abstract=True): def get_id(self, post: RawPost) -> Any: "Get post id of given RawPost" - async def do_parse(self, raw_post: RawPost) -> Post: post_id = self.get_id(raw_post) if post_id not in self.parse_cache: @@ -156,8 +165,11 @@ class MessageProcess(Platform, abstract=True): # post_id = self.get_id(raw_post) # if post_id in exists_posts_set: # continue - if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60 and \ - plugin_config.bison_init_filter: + if ( + (post_time := self.get_date(raw_post)) + and time.time() - post_time > 2 * 60 * 60 + and plugin_config.bison_init_filter + ): continue try: self.get_category(raw_post) @@ -168,15 +180,18 @@ class MessageProcess(Platform, abstract=True): res.append(raw_post) return res + class NewMessage(MessageProcess, abstract=True): "Fetch a list of messages, filter the new messages, dispatch it to different users" @dataclass - class MessageStorage(): + class MessageStorage: inited: bool exists_posts: set[Any] - async def filter_common_with_diff(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]: + async def filter_common_with_diff( + self, target: Target, raw_post_list: list[RawPost] + ) -> list[RawPost]: filtered_post = await self.filter_common(raw_post_list) store = self.get_stored_data(target) or self.MessageStorage(False, set()) res = [] @@ -185,7 +200,11 @@ class NewMessage(MessageProcess, abstract=True): for raw_post in filtered_post: post_id = self.get_id(raw_post) store.exists_posts.add(post_id) - logger.info('init {}-{} with {}'.format(self.platform_name, target, store.exists_posts)) + logger.info( + "init {}-{} with {}".format( + self.platform_name, target, store.exists_posts + ) + ) store.inited = True else: for raw_post in filtered_post: @@ -197,8 +216,9 @@ class NewMessage(MessageProcess, abstract=True): self.set_stored_data(target, store) return res - - async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: + async def fetch_new_post( + self, target: Target, users: list[UserSubInfo] + ) -> list[tuple[User, list[Post]]]: try: post_list = await self.get_sub_list(target) new_posts = await self.filter_common_with_diff(target, post_list) @@ -206,17 +226,25 @@ class NewMessage(MessageProcess, abstract=True): return [] else: for post in new_posts: - logger.info('fetch new post from {} {}: {}'.format( - self.platform_name, - target if self.has_target else '-', - self.get_id(post))) + logger.info( + "fetch new post from {} {}: {}".format( + self.platform_name, + target if self.has_target else "-", + self.get_id(post), + ) + ) res = await self.dispatch_user_post(target, new_posts, users) self.parse_cache = {} return res except httpx.RequestError as err: - logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) + logger.warning( + "network connection error: {}, url: {}".format( + type(err), err.request.url + ) + ) return [] + class StatusChange(Platform, abstract=True): "Watch a status, and fire a post when status changes" @@ -232,49 +260,69 @@ class StatusChange(Platform, abstract=True): async def parse(self, raw_post: RawPost) -> Post: ... - async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: + async def fetch_new_post( + self, target: Target, users: list[UserSubInfo] + ) -> list[tuple[User, list[Post]]]: try: new_status = await self.get_status(target) res = [] if old_status := self.get_stored_data(target): diff = self.compare_status(target, old_status, new_status) if diff: - logger.info("status changes {} {}: {} -> {}".format( - self.platform_name, - target if self.has_target else '-', - old_status, new_status - )) + logger.info( + "status changes {} {}: {} -> {}".format( + self.platform_name, + target if self.has_target else "-", + old_status, + new_status, + ) + ) res = await self.dispatch_user_post(target, diff, users) self.set_stored_data(target, new_status) return res except httpx.RequestError as err: - logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) + logger.warning( + "network connection error: {}, url: {}".format( + type(err), err.request.url + ) + ) return [] + class SimplePost(MessageProcess, abstract=True): "Fetch a list of messages, dispatch it to different users" - async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]: + async def fetch_new_post( + self, target: Target, users: list[UserSubInfo] + ) -> list[tuple[User, list[Post]]]: try: new_posts = await self.get_sub_list(target) if not new_posts: return [] else: for post in new_posts: - logger.info('fetch new post from {} {}: {}'.format( - self.platform_name, - target if self.has_target else '-', - self.get_id(post))) + logger.info( + "fetch new post from {} {}: {}".format( + self.platform_name, + target if self.has_target else "-", + self.get_id(post), + ) + ) res = await self.dispatch_user_post(target, new_posts, users) self.parse_cache = {} return res except httpx.RequestError as err: - logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) + logger.warning( + "network connection error: {}, url: {}".format( + type(err), err.request.url + ) + ) return [] + class NoTargetGroup(Platform, abstract=True): enable_tag = False - DUMMY_STR = '_DUMMY' + DUMMY_STR = "_DUMMY" enabled = True has_target = False @@ -287,24 +335,35 @@ class NoTargetGroup(Platform, abstract=True): self.schedule_kw = platform_list[0].schedule_kw for platform in platform_list: if platform.has_target: - raise RuntimeError('Platform {} should have no target'.format(platform.name)) + raise RuntimeError( + "Platform {} should have no target".format(platform.name) + ) if name == self.DUMMY_STR: name = platform.name elif name != platform.name: - raise RuntimeError('Platform name for {} not fit'.format(self.platform_name)) + raise RuntimeError( + "Platform name for {} not fit".format(self.platform_name) + ) platform_category_key_set = set(platform.categories.keys()) if platform_category_key_set & categories_keys: - raise RuntimeError('Platform categories for {} duplicate'.format(self.platform_name)) + raise RuntimeError( + "Platform categories for {} duplicate".format(self.platform_name) + ) categories_keys |= platform_category_key_set self.categories.update(platform.categories) - if platform.schedule_kw != self.schedule_kw or platform.schedule_type != self.schedule_type: - raise RuntimeError('Platform scheduler for {} not fit'.format(self.platform_name)) + if ( + platform.schedule_kw != self.schedule_kw + or platform.schedule_type != self.schedule_type + ): + raise RuntimeError( + "Platform scheduler for {} not fit".format(self.platform_name) + ) self.name = name self.is_common = platform_list[0].is_common super().__init__() def __str__(self): - return '[' + ' '.join(map(lambda x: x.name, self.platform_list)) + ']' + return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]" async def get_target_name(self, _): return await self.platform_list[0].get_target_name(_) @@ -316,4 +375,3 @@ class NoTargetGroup(Platform, abstract=True): for user, posts in platform_res: res[user].extend(posts) return [[key, val] for key, val in res.items()] - diff --git a/src/plugins/nonebot_bison/platform/rss.py b/src/plugins/nonebot_bison/platform/rss.py index 4cc18cc..330d93d 100644 --- a/src/plugins/nonebot_bison/platform/rss.py +++ b/src/plugins/nonebot_bison/platform/rss.py @@ -1,31 +1,32 @@ import calendar from typing import Any, Optional -from bs4 import BeautifulSoup as bs import feedparser import httpx +from bs4 import BeautifulSoup as bs from ..post import Post from ..types import RawPost, Target from .platform import NewMessage + class Rss(NewMessage): categories = {} enable_tag = False - platform_name = 'rss' + platform_name = "rss" name = "Rss" enabled = True is_common = True - schedule_type = 'interval' - schedule_kw = {'seconds': 30} + schedule_type = "interval" + schedule_kw = {"seconds": 30} has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: res = await client.get(target, timeout=10.0) feed = feedparser.parse(res.text) - return feed['feed']['title'] + return feed["feed"]["title"] def get_date(self, post: RawPost) -> int: return calendar.timegm(post.published_parsed) @@ -39,12 +40,18 @@ class Rss(NewMessage): feed = feedparser.parse(res) entries = feed.entries for entry in entries: - entry['_target_name'] = feed.feed.title + entry["_target_name"] = feed.feed.title return feed.entries async def parse(self, raw_post: RawPost) -> Post: - text = raw_post.get('title', '') + '\n' if raw_post.get('title') else '' - soup = bs(raw_post.description, 'html.parser') + text = raw_post.get("title", "") + "\n" if raw_post.get("title") else "" + soup = bs(raw_post.description, "html.parser") text += soup.text.strip() - pics = list(map(lambda x: x.attrs['src'], soup('img'))) - return Post('rss', text=text, url=raw_post.link, pics=pics, target_name=raw_post['_target_name']) + pics = list(map(lambda x: x.attrs["src"], soup("img"))) + return Post( + "rss", + text=text, + url=raw_post.link, + pics=pics, + target_name=raw_post["_target_name"], + ) diff --git a/src/plugins/nonebot_bison/platform/wechat.py b/src/plugins/nonebot_bison/platform/wechat.py index 7c04306..d5f5487 100644 --- a/src/plugins/nonebot_bison/platform/wechat.py +++ b/src/plugins/nonebot_bison/platform/wechat.py @@ -1,14 +1,15 @@ -from datetime import datetime import hashlib import json import re +from datetime import datetime from typing import Any, Optional -from bs4 import BeautifulSoup as bs import httpx +from bs4 import BeautifulSoup as bs from ..post import Post from ..types import * + # from .platform import Platform @@ -75,4 +76,3 @@ from ..types import * # pics=[], # url='' # ) - diff --git a/src/plugins/nonebot_bison/platform/weibo.py b/src/plugins/nonebot_bison/platform/weibo.py index 19d8703..365e3b2 100644 --- a/src/plugins/nonebot_bison/platform/weibo.py +++ b/src/plugins/nonebot_bison/platform/weibo.py @@ -1,121 +1,152 @@ -from datetime import datetime import json import re +from datetime import datetime from typing import Any, Optional -from bs4 import BeautifulSoup as bs import httpx +from bs4 import BeautifulSoup as bs from nonebot import logger from ..post import Post from ..types import * from .platform import NewMessage + class Weibo(NewMessage): categories = { - 1: '转发', - 2: '视频', - 3: '图文', - 4: '文字', - } + 1: "转发", + 2: "视频", + 3: "图文", + 4: "文字", + } enable_tag = True - platform_name = 'weibo' - name = '新浪微博' + platform_name = "weibo" + name = "新浪微博" enabled = True is_common = True - schedule_type = 'interval' - schedule_kw = {'seconds': 3} + schedule_type = "interval" + schedule_kw = {"seconds": 3} has_target = True async def get_target_name(self, target: Target) -> Optional[str]: async with httpx.AsyncClient() as client: - param = {'containerid': '100505' + target} - res = await client.get('https://m.weibo.cn/api/container/getIndex', params=param) + param = {"containerid": "100505" + target} + res = await client.get( + "https://m.weibo.cn/api/container/getIndex", params=param + ) res_dict = json.loads(res.text) - if res_dict.get('ok') == 1: - return res_dict['data']['userInfo']['screen_name'] + if res_dict.get("ok") == 1: + return res_dict["data"]["userInfo"]["screen_name"] else: return None async def get_sub_list(self, target: Target) -> list[RawPost]: async with httpx.AsyncClient() as client: - params = { 'containerid': '107603' + target} - res = await client.get('https://m.weibo.cn/api/container/getIndex?', params=params, timeout=4.0) + params = {"containerid": "107603" + target} + res = await client.get( + "https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0 + ) res_data = json.loads(res.text) - if not res_data['ok']: + if not res_data["ok"]: return [] - custom_filter: Callable[[RawPost], bool] = lambda d: d['card_type'] == 9 - return list(filter(custom_filter, res_data['data']['cards'])) + custom_filter: Callable[[RawPost], bool] = lambda d: d["card_type"] == 9 + return list(filter(custom_filter, res_data["data"]["cards"])) def get_id(self, post: RawPost) -> Any: - return post['mblog']['id'] + return post["mblog"]["id"] def filter_platform_custom(self, raw_post: RawPost) -> bool: - return raw_post['card_type'] == 9 + return raw_post["card_type"] == 9 def get_date(self, raw_post: RawPost) -> float: - created_time = datetime.strptime(raw_post['mblog']['created_at'], '%a %b %d %H:%M:%S %z %Y') + created_time = datetime.strptime( + raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y" + ) return created_time.timestamp() def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]: "Return Tag list of given RawPost" - text = raw_post['mblog']['text'] - soup = bs(text, 'html.parser') - res = list(map( - lambda x: x[1:-1], - filter( - lambda s: s[0] == '#' and s[-1] == '#', - map(lambda x:x.text, soup.find_all('span', class_='surl-text')) - ) - )) - super_topic_img = soup.find('img', src=re.compile(r'timeline_card_small_super_default')) + text = raw_post["mblog"]["text"] + soup = bs(text, "html.parser") + res = list( + map( + lambda x: x[1:-1], + filter( + lambda s: s[0] == "#" and s[-1] == "#", + map(lambda x: x.text, soup.find_all("span", class_="surl-text")), + ), + ) + ) + super_topic_img = soup.find( + "img", src=re.compile(r"timeline_card_small_super_default") + ) if super_topic_img: try: - res.append(super_topic_img.parent.parent.find('span', class_='surl-text').text + '超话') + res.append( + super_topic_img.parent.parent.find("span", class_="surl-text").text + + "超话" + ) except: - logger.info('super_topic extract error: {}'.format(text)) + logger.info("super_topic extract error: {}".format(text)) return res def get_category(self, raw_post: RawPost) -> Category: - if raw_post['mblog'].get('retweeted_status'): + if raw_post["mblog"].get("retweeted_status"): return Category(1) - elif raw_post['mblog'].get('page_info') and raw_post['mblog']['page_info'].get('type') == 'video': + elif ( + raw_post["mblog"].get("page_info") + and raw_post["mblog"]["page_info"].get("type") == "video" + ): return Category(2) - elif raw_post['mblog'].get('pics'): + elif raw_post["mblog"].get("pics"): return Category(3) else: return Category(4) def _get_text(self, raw_text: str) -> str: - text = raw_text.replace('
', '\n') - return bs(text, 'html.parser').text + text = raw_text.replace("
", "\n") + return bs(text, "html.parser").text async def parse(self, raw_post: RawPost) -> Post: header = { - 'accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9', - 'accept-language': 'zh-CN,zh;q=0.9', - 'authority': 'm.weibo.cn', - 'cache-control': 'max-age=0', - 'sec-fetch-dest': 'empty', - 'sec-fetch-mode': 'same-origin', - 'sec-fetch-site': 'same-origin', - 'upgrade-insecure-requests': '1', - 'user-agent': 'Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) ' - 'AppleWebKit/537.36 (KHTML, like Gecko) Chrome/89.0.4389.72 ' - 'Mobile Safari/537.36'} - info = raw_post['mblog'] - if info['isLongText'] or info['pic_num'] > 9: + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", + "accept-language": "zh-CN,zh;q=0.9", + "authority": "m.weibo.cn", + "cache-control": "max-age=0", + "sec-fetch-dest": "empty", + "sec-fetch-mode": "same-origin", + "sec-fetch-site": "same-origin", + "upgrade-insecure-requests": "1", + "user-agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) " + "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/89.0.4389.72 " + "Mobile Safari/537.36", + } + info = raw_post["mblog"] + if info["isLongText"] or info["pic_num"] > 9: async with httpx.AsyncClient() as client: - res = await client.get('https://m.weibo.cn/detail/{}'.format(info['mid']), headers=header) + res = await client.get( + "https://m.weibo.cn/detail/{}".format(info["mid"]), headers=header + ) try: - full_json_text = re.search(r'"status": ([\s\S]+),\s+"hotScheme"', res.text).group(1) + full_json_text = re.search( + r'"status": ([\s\S]+),\s+"hotScheme"', res.text + ).group(1) info = json.loads(full_json_text) except: - logger.info('detail message error: https://m.weibo.cn/detail/{}'.format(info['mid'])) - parsed_text = self._get_text(info['text']) - pic_urls = [img['large']['url'] for img in info.get('pics', [])] - detail_url = 'https://weibo.com/{}/{}'.format(info['user']['id'], info['bid']) + logger.info( + "detail message error: https://m.weibo.cn/detail/{}".format( + info["mid"] + ) + ) + parsed_text = self._get_text(info["text"]) + pic_urls = [img["large"]["url"] for img in info.get("pics", [])] + detail_url = "https://weibo.com/{}/{}".format(info["user"]["id"], info["bid"]) # return parsed_text, detail_url, pic_urls - return Post('weibo', text=parsed_text, url=detail_url, pics=pic_urls, target_name=info['user']['screen_name']) - + return Post( + "weibo", + text=parsed_text, + url=detail_url, + pics=pic_urls, + target_name=info["user"]["screen_name"], + ) diff --git a/src/plugins/nonebot_bison/plugin_config.py b/src/plugins/nonebot_bison/plugin_config.py index e477b83..2e48b4e 100644 --- a/src/plugins/nonebot_bison/plugin_config.py +++ b/src/plugins/nonebot_bison/plugin_config.py @@ -1,23 +1,25 @@ +import warnings + +import nonebot from pydantic import BaseSettings -import warnings -import nonebot class PlugConfig(BaseSettings): bison_config_path: str = "" bison_use_pic: bool = False bison_use_local: bool = False - bison_browser: str = '' + bison_browser: str = "" bison_init_filter: bool = True bison_use_queue: bool = True - bison_outer_url: str = 'http://localhost:8080/bison/' + bison_outer_url: str = "http://localhost:8080/bison/" bison_filter_log: bool = False class Config: - extra = 'ignore' + extra = "ignore" + global_config = nonebot.get_driver().config plugin_config = PlugConfig(**global_config.dict()) if plugin_config.bison_use_local: - warnings.warn('BISON_USE_LOCAL is deprecated, please use BISON_BROWSER') + warnings.warn("BISON_USE_LOCAL is deprecated, please use BISON_BROWSER") diff --git a/src/plugins/nonebot_bison/post.py b/src/plugins/nonebot_bison/post.py index 8a77a3d..604295d 100644 --- a/src/plugins/nonebot_bison/post.py +++ b/src/plugins/nonebot_bison/post.py @@ -4,14 +4,15 @@ from functools import reduce from io import BytesIO from typing import Optional, Union -from PIL import Image import httpx from nonebot import logger from nonebot.adapters.cqhttp.message import Message, MessageSegment +from PIL import Image from .plugin_config import plugin_config from .utils import parse_text + @dataclass class Post: @@ -21,7 +22,9 @@ class Post: target_name: Optional[str] = None compress: bool = False override_use_pic: Optional[bool] = None - pics: Union[list[Union[str,bytes]], list[str], list[bytes]] = field(default_factory=list) + pics: Union[list[Union[str, bytes]], list[str], list[bytes]] = field( + default_factory=list + ) extra_msg: list[Message] = field(default_factory=list) _message: Optional[list] = None @@ -56,7 +59,7 @@ class Post: cur_img = await self._pic_url_to_image(self.pics[i]) if not self._check_image_square(cur_img.size): return - if cur_img.size[1] != images[0].size[1]: # height not equal + if cur_img.size[1] != images[0].size[1]: # height not equal return images.append(cur_img) _tmp = 0 @@ -65,6 +68,7 @@ class Post: _tmp += images[i].size[0] x_coord.append(_tmp) y_coord = [0, first_image.size[1]] + async def process_row(row: int) -> bool: if len(self.pics) < (row + 1) * 3: return False @@ -86,44 +90,48 @@ class Post: images.extend(image_row) y_coord.append(y_coord[-1] + row_first_img.size[1]) return True + if await process_row(1): - matrix = (3,2) + matrix = (3, 2) else: - matrix = (3,1) + matrix = (3, 1) if await process_row(2): - matrix = (3,3) - logger.info('trigger merge image') - target = Image.new('RGB', (x_coord[-1], y_coord[-1])) + matrix = (3, 3) + logger.info("trigger merge image") + target = Image.new("RGB", (x_coord[-1], y_coord[-1])) for y in range(matrix[1]): for x in range(matrix[0]): - target.paste(images[y * matrix[0] + x], ( - x_coord[x], y_coord[y], x_coord[x+1], y_coord[y+1] - )) + target.paste( + images[y * matrix[0] + x], + (x_coord[x], y_coord[y], x_coord[x + 1], y_coord[y + 1]), + ) target_io = BytesIO() - target.save(target_io, 'JPEG') - self.pics = self.pics[matrix[0] * matrix[1]: ] + target.save(target_io, "JPEG") + self.pics = self.pics[matrix[0] * matrix[1] :] self.pics.insert(0, target_io.getvalue()) async def generate_messages(self): if self._message is None: await self._pic_merge() msgs = [] - text = '' + text = "" if self.text: if self._use_pic(): - text += '{}'.format(self.text) + text += "{}".format(self.text) else: - text += '{}'.format(self.text if len(self.text) < 500 else self.text[:500] + '...') - text += '\n来源: {}'.format(self.target_type) + text += "{}".format( + self.text if len(self.text) < 500 else self.text[:500] + "..." + ) + text += "\n来源: {}".format(self.target_type) if self.target_name: - text += ' {}'.format(self.target_name) + text += " {}".format(self.target_name) if self._use_pic(): msgs.append(await parse_text(text)) - if not self.target_type == 'rss' and self.url: + if not self.target_type == "rss" and self.url: msgs.append(MessageSegment.text(self.url)) else: if self.url: - text += ' \n详情: {}'.format(self.url) + text += " \n详情: {}".format(self.url) msgs.append(MessageSegment.text(text)) for pic in self.pics: # if isinstance(pic, bytes): @@ -137,10 +145,17 @@ class Post: return self._message def __str__(self): - return 'type: {}\nfrom: {}\ntext: {}\nurl: {}\npic: {}'.format( - self.target_type, - self.target_name, - self.text if len(self.text) < 500 else self.text[:500] + '...', - self.url, - ', '.join(map(lambda x: 'b64img' if isinstance(x, bytes) or x.startswith('base64') else x, self.pics)) - ) + return "type: {}\nfrom: {}\ntext: {}\nurl: {}\npic: {}".format( + self.target_type, + self.target_name, + self.text if len(self.text) < 500 else self.text[:500] + "...", + self.url, + ", ".join( + map( + lambda x: "b64img" + if isinstance(x, bytes) or x.startswith("base64") + else x, + self.pics, + ) + ), + ) diff --git a/src/plugins/nonebot_bison/scheduler.py b/src/plugins/nonebot_bison/scheduler.py index 949504c..9e8eaf7 100644 --- a/src/plugins/nonebot_bison/scheduler.py +++ b/src/plugins/nonebot_bison/scheduler.py @@ -1,68 +1,91 @@ import asyncio import logging -from apscheduler.schedulers.asyncio import AsyncIOScheduler -import logging import nonebot +from apscheduler.schedulers.asyncio import AsyncIOScheduler from nonebot import get_driver, logger from nonebot.log import LoguruHandler from .config import Config from .platform import platform_manager -from .send import do_send_msgs -from .send import send_msgs -from .types import UserSubInfo from .plugin_config import plugin_config +from .send import do_send_msgs, send_msgs +from .types import UserSubInfo scheduler = AsyncIOScheduler() + @get_driver().on_startup async def _start(): scheduler.configure({"apscheduler.timezone": "Asia/Shanghai"}) scheduler.start() + # get_driver().on_startup(_start) + async def fetch_and_send(target_type: str): config = Config() target = config.get_next_target(target_type) if not target: return - logger.debug('try to fecth new posts from {}, target: {}'.format(target_type, target)) + logger.debug( + "try to fecth new posts from {}, target: {}".format(target_type, target) + ) send_user_list = config.target_user_cache[target_type][target] - send_userinfo_list = list(map( - lambda user: UserSubInfo( - user, - lambda target: config.get_sub_category(target_type, target, user.user_type, user.user), - lambda target: config.get_sub_tags(target_type, target, user.user_type, user.user) - ), send_user_list)) + send_userinfo_list = list( + map( + lambda user: UserSubInfo( + user, + lambda target: config.get_sub_category( + target_type, target, user.user_type, user.user + ), + lambda target: config.get_sub_tags( + target_type, target, user.user_type, user.user + ), + ), + send_user_list, + ) + ) bot_list = list(nonebot.get_bots().values()) bot = bot_list[0] if bot_list else None - to_send = await platform_manager[target_type].fetch_new_post(target, send_userinfo_list) + to_send = await platform_manager[target_type].fetch_new_post( + target, send_userinfo_list + ) for user, send_list in to_send: for send_post in send_list: - logger.info('send to {}: {}'.format(user, send_post)) + logger.info("send to {}: {}".format(user, send_post)) if not bot: - logger.warning('no bot connected') + logger.warning("no bot connected") else: - await send_msgs(bot, user.user, user.user_type, await send_post.generate_messages()) + await send_msgs( + bot, user.user, user.user_type, await send_post.generate_messages() + ) + for platform_name, platform in platform_manager.items(): - if platform.schedule_type in ['cron', 'interval', 'date']: - logger.info(f'start scheduler for {platform_name} with {platform.schedule_type} {platform.schedule_kw}') + if platform.schedule_type in ["cron", "interval", "date"]: + logger.info( + f"start scheduler for {platform_name} with {platform.schedule_type} {platform.schedule_kw}" + ) scheduler.add_job( - fetch_and_send, platform.schedule_type, **platform.schedule_kw, - args=(platform_name,)) + fetch_and_send, + platform.schedule_type, + **platform.schedule_kw, + args=(platform_name,), + ) + class CustomLogHandler(LoguruHandler): - def filter(self, record: logging.LogRecord): - return record.msg != ('Execution of job "%s" ' - 'skipped: maximum number of running instances reached (%d)') + return record.msg != ( + 'Execution of job "%s" ' + "skipped: maximum number of running instances reached (%d)" + ) if plugin_config.bison_use_queue: - scheduler.add_job(do_send_msgs, 'interval', seconds=0.3, coalesce=True) + scheduler.add_job(do_send_msgs, "interval", seconds=0.3, coalesce=True) aps_logger = logging.getLogger("apscheduler") aps_logger.setLevel(30) diff --git a/src/plugins/nonebot_bison/send.py b/src/plugins/nonebot_bison/send.py index 023de7b..8abeece 100644 --- a/src/plugins/nonebot_bison/send.py +++ b/src/plugins/nonebot_bison/send.py @@ -8,11 +8,13 @@ from .plugin_config import plugin_config QUEUE = [] LAST_SEND_TIME = time.time() -async def _do_send(bot: 'Bot', user: str, user_type: str, msg): - if user_type == 'group': - await bot.call_api('send_group_msg', group_id=user, message=msg) - elif user_type == 'private': - await bot.call_api('send_private_msg', user_id=user, message=msg) + +async def _do_send(bot: "Bot", user: str, user_type: str, msg): + if user_type == "group": + await bot.call_api("send_group_msg", group_id=user, message=msg) + elif user_type == "private": + await bot.call_api("send_private_msg", user_id=user, message=msg) + async def do_send_msgs(): global LAST_SEND_TIME @@ -28,10 +30,11 @@ async def do_send_msgs(): else: msg_str = str(msg) if len(msg_str) > 50: - msg_str = msg_str[:50] + '...' - logger.warning(f'send msg err {e} {msg_str}') + msg_str = msg_str[:50] + "..." + logger.warning(f"send msg err {e} {msg_str}") LAST_SEND_TIME = time.time() + async def send_msgs(bot, user, user_type, msgs): if plugin_config.bison_use_queue: for msg in msgs: @@ -39,5 +42,3 @@ async def send_msgs(bot, user, user_type, msgs): else: for msg in msgs: await _do_send(bot, user, user_type, msg) - - diff --git a/src/plugins/nonebot_bison/types.py b/src/plugins/nonebot_bison/types.py index 089b2b0..8073fa8 100644 --- a/src/plugins/nonebot_bison/types.py +++ b/src/plugins/nonebot_bison/types.py @@ -1,16 +1,18 @@ -from typing import Any, Callable, NamedTuple, NewType from dataclasses import dataclass +from typing import Any, Callable, NamedTuple, NewType + +RawPost = NewType("RawPost", Any) +Target = NewType("Target", str) +Category = NewType("Category", int) +Tag = NewType("Tag", str) -RawPost = NewType('RawPost', Any) -Target = NewType('Target', str) -Category = NewType('Category', int) -Tag = NewType('Tag', str) @dataclass(eq=True, frozen=True) class User: user: str user_type: str + class UserSubInfo(NamedTuple): user: User category_getter: Callable[[Target], list[Category]] diff --git a/tests/conftest.py b/tests/conftest.py index c2bce56..c15b4a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,26 +1,27 @@ -import pytest -import nonebot import typing +import nonebot +import pytest + if typing.TYPE_CHECKING: import sys - sys.path.append('./src/plugins') + + sys.path.append("./src/plugins") import nonebot_bison -@pytest.fixture#(scope="module") + +@pytest.fixture # (scope="module") def plugin_module(tmpdir): nonebot.init(bison_config_path=str(tmpdir)) - nonebot.load_plugins('src/plugins') + nonebot.load_plugins("src/plugins") plugins = nonebot.get_loaded_plugins() - plugin = list(filter(lambda x: x.name == 'nonebot_bison', plugins))[0] + plugin = list(filter(lambda x: x.name == "nonebot_bison", plugins))[0] return plugin.module -@pytest.fixture -def dummy_user_subinfo(plugin_module: 'nonebot_bison'): - user = plugin_module.types.User('123', 'group') - return plugin_module.types.UserSubInfo( - user=user, - category_getter=lambda _: [], - tag_getter=lambda _: [] - ) +@pytest.fixture +def dummy_user_subinfo(plugin_module: "nonebot_bison"): + user = plugin_module.types.User("123", "group") + return plugin_module.types.UserSubInfo( + user=user, category_getter=lambda _: [], tag_getter=lambda _: [] + ) diff --git a/tests/platforms/test_arknights.py b/tests/platforms/test_arknights.py index 1058019..7734791 100644 --- a/tests/platforms/test_arknights.py +++ b/tests/platforms/test_arknights.py @@ -1,64 +1,93 @@ -import pytest import typing + +import pytest import respx from httpx import Response if typing.TYPE_CHECKING: import sys - sys.path.append('./src/plugins') + + sys.path.append("./src/plugins") import nonebot_bison -from .utils import get_json, get_file +from .utils import get_file, get_json + @pytest.fixture -def arknights(plugin_module: 'nonebot_bison'): - return plugin_module.platform.platform_manager['arknights'] +def arknights(plugin_module: "nonebot_bison"): + return plugin_module.platform.platform_manager["arknights"] -@pytest.fixture(scope='module') + +@pytest.fixture(scope="module") def arknights_list_0(): - return get_json('arknights_list_0.json') + return get_json("arknights_list_0.json") -@pytest.fixture(scope='module') + +@pytest.fixture(scope="module") def arknights_list_1(): - return get_json('arknights_list_1.json') + return get_json("arknights_list_1.json") -@pytest.fixture(scope='module') + +@pytest.fixture(scope="module") def monster_siren_list_0(): - return get_json('monster-siren_list_0.json') + return get_json("monster-siren_list_0.json") -@pytest.fixture(scope='module') + +@pytest.fixture(scope="module") def monster_siren_list_1(): - return get_json('monster-siren_list_1.json') + return get_json("monster-siren_list_1.json") + @pytest.mark.asyncio @respx.mock -async def test_fetch_new(arknights, dummy_user_subinfo, arknights_list_0, arknights_list_1, monster_siren_list_0, monster_siren_list_1): - ak_list_router = respx.get("https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/announcement.meta.json") - detail_router = respx.get("https://ak.hycdn.cn/announce/IOS/announcement/805_1640074952.html") - version_router = respx.get('https://ak-conf.hypergryph.com/config/prod/official/IOS/version') - preannouncement_router = respx.get('https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json') +async def test_fetch_new( + arknights, + dummy_user_subinfo, + arknights_list_0, + arknights_list_1, + monster_siren_list_0, + monster_siren_list_1, +): + ak_list_router = respx.get( + "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/announcement.meta.json" + ) + detail_router = respx.get( + "https://ak.hycdn.cn/announce/IOS/announcement/805_1640074952.html" + ) + version_router = respx.get( + "https://ak-conf.hypergryph.com/config/prod/official/IOS/version" + ) + preannouncement_router = respx.get( + "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json" + ) monster_siren_router = respx.get("https://monster-siren.hypergryph.com/api/news") ak_list_router.mock(return_value=Response(200, json=arknights_list_0)) - detail_router.mock(return_value=Response(200, text=get_file('arknights-detail-805'))) - version_router.mock(return_value=Response(200, json=get_json('arknights-version-0.json'))) - preannouncement_router.mock(return_value=Response(200, json=get_json('arknights-pre-0.json'))) + detail_router.mock( + return_value=Response(200, text=get_file("arknights-detail-805")) + ) + version_router.mock( + return_value=Response(200, json=get_json("arknights-version-0.json")) + ) + preannouncement_router.mock( + return_value=Response(200, json=get_json("arknights-pre-0.json")) + ) monster_siren_router.mock(return_value=Response(200, json=monster_siren_list_0)) - target = '' + target = "" res = await arknights.fetch_new_post(target, [dummy_user_subinfo]) - assert(ak_list_router.called) - assert(len(res) == 0) - assert(not detail_router.called) + assert ak_list_router.called + assert len(res) == 0 + assert not detail_router.called mock_data = arknights_list_1 ak_list_router.mock(return_value=Response(200, json=mock_data)) res3 = await arknights.fetch_new_post(target, [dummy_user_subinfo]) - assert(len(res3[0][1]) == 1) - assert(detail_router.called) + assert len(res3[0][1]) == 1 + assert detail_router.called post = res3[0][1][0] - assert(post.target_type == 'arknights') - assert(post.text == '') - assert(post.url == '') - assert(post.target_name == '明日方舟游戏内公告') - assert(len(post.pics) == 1) + assert post.target_type == "arknights" + assert post.text == "" + assert post.url == "" + assert post.target_name == "明日方舟游戏内公告" + assert len(post.pics) == 1 # assert(post.pics == ['https://ak-fs.hypergryph.com/announce/images/20210623/e6f49aeb9547a2278678368a43b95b07.jpg']) print(res3[0][1]) r = await post.generate_messages() diff --git a/tests/platforms/test_bilibili.py b/tests/platforms/test_bilibili.py index d04c604..4b99c83 100644 --- a/tests/platforms/test_bilibili.py +++ b/tests/platforms/test_bilibili.py @@ -1,37 +1,53 @@ -import pytest import typing + +import pytest from httpx import Response if typing.TYPE_CHECKING: import sys - sys.path.append('./src/plugins') + + sys.path.append("./src/plugins") import nonebot_bison from .utils import get_json -@pytest.fixture(scope='module') + +@pytest.fixture(scope="module") def bing_dy_list(): - return get_json('bilibili_bing_list.json')['data']['cards'] + return get_json("bilibili_bing_list.json")["data"]["cards"] + @pytest.fixture -def bilibili(plugin_module: 'nonebot_bison'): - return plugin_module.platform.platform_manager['bilibili'] +def bilibili(plugin_module: "nonebot_bison"): + return plugin_module.platform.platform_manager["bilibili"] + @pytest.mark.asyncio async def test_video_forward(bilibili, bing_dy_list): - post = await bilibili.parse(bing_dy_list[1]) - assert(post.text == '答案揭晓:宿舍!来看看投票结果\nhttps://t.bilibili.com/568093580488553786\n--------------\n#可露希尔的秘密档案# \n11:来宿舍休息一下吧 \n档案来源:lambda:\\罗德岛内务\\秘密档案 \n发布时间:9/12 1:00 P.M. \n档案类型:可见 \n档案描述:今天请了病假在宿舍休息。很舒适。 \n提供者:赫默') + post = await bilibili.parse(bing_dy_list[1]) + assert ( + post.text + == "答案揭晓:宿舍!来看看投票结果\nhttps://t.bilibili.com/568093580488553786\n--------------\n#可露希尔的秘密档案# \n11:来宿舍休息一下吧 \n档案来源:lambda:\\罗德岛内务\\秘密档案 \n发布时间:9/12 1:00 P.M. \n档案类型:可见 \n档案描述:今天请了病假在宿舍休息。很舒适。 \n提供者:赫默" + ) + @pytest.mark.asyncio async def test_article_forward(bilibili, bing_dy_list): - post = await bilibili.parse(bing_dy_list[4]) - assert(post.text == '#明日方舟##饼学大厦#\n9.11专栏更新完毕,这还塌了实属没跟新运营对上\n后边除了周日发饼和PV没提及的中文语音,稳了\n别忘了来参加#可露希尔的秘密档案#的主题投票\nhttps://t.bilibili.com/568093580488553786?tab=2' + - '\n--------------\n' + - '【明日方舟】饼学大厦#12~14(风暴瞭望&玛莉娅·临光&红松林&感谢庆典)9.11更新 更新记录09.11更新:覆盖09.10更新;以及排期更新,猜测周一周五开活动09.10更新:以周五开活动为底,PV/公告调整位置,整体结构更新09.08更新:饼学大厦#12更新,新增一件六星商店服饰(周日发饼)09.06更新:饼学大厦整栋整栋翻新,改为9.16开主线(四日无饼!)09.05凌晨更新:10.13后的排期(两日无饼,鹰角背刺,心狠手辣)前言感谢楪筱祈ぺ的动态-哔哩哔哩 (bilibili.com) 对饼学的贡献!后续排期:9.17【风暴瞭望】、10.01【玛莉娅·临光】复刻、10.1') + post = await bilibili.parse(bing_dy_list[4]) + assert ( + post.text + == "#明日方舟##饼学大厦#\n9.11专栏更新完毕,这还塌了实属没跟新运营对上\n后边除了周日发饼和PV没提及的中文语音,稳了\n别忘了来参加#可露希尔的秘密档案#的主题投票\nhttps://t.bilibili.com/568093580488553786?tab=2" + + "\n--------------\n" + + "【明日方舟】饼学大厦#12~14(风暴瞭望&玛莉娅·临光&红松林&感谢庆典)9.11更新 更新记录09.11更新:覆盖09.10更新;以及排期更新,猜测周一周五开活动09.10更新:以周五开活动为底,PV/公告调整位置,整体结构更新09.08更新:饼学大厦#12更新,新增一件六星商店服饰(周日发饼)09.06更新:饼学大厦整栋整栋翻新,改为9.16开主线(四日无饼!)09.05凌晨更新:10.13后的排期(两日无饼,鹰角背刺,心狠手辣)前言感谢楪筱祈ぺ的动态-哔哩哔哩 (bilibili.com) 对饼学的贡献!后续排期:9.17【风暴瞭望】、10.01【玛莉娅·临光】复刻、10.1" + ) + @pytest.mark.asyncio async def test_dynamic_forward(bilibili, bing_dy_list): post = await bilibili.parse(bing_dy_list[5]) - assert(post.text == '饼组主线饼学预测——9.11版\n①今日结果\n9.11 殿堂上的游禽-星极(x,新运营实锤了)\n②后续预测\n9.12 #罗德岛相簿#+#可露希尔的秘密档案#11话\n9.13 六星先锋(执旗手)干员-琴柳\n9.14 宣传策略-空弦+家具\n9.15 轮换池(+中文语音前瞻)\n9.16 停机\n9.17 #罗德岛闲逛部#+新六星EP+EP09·风暴瞭望开启\n9.19 #罗德岛相簿#' + - '\n--------------\n' + - '#明日方舟#\n【新增服饰】\n//殿堂上的游禽 - 星极\n塞壬唱片偶像企划《闪耀阶梯》特供服饰/殿堂上的游禽。星极自费参加了这项企划,尝试着用大众能接受的方式演绎天空之上的故事。\n\n_____________\n谦逊留给观众,骄傲发自歌喉,此夜,唯我璀璨。 ') + assert ( + post.text + == "饼组主线饼学预测——9.11版\n①今日结果\n9.11 殿堂上的游禽-星极(x,新运营实锤了)\n②后续预测\n9.12 #罗德岛相簿#+#可露希尔的秘密档案#11话\n9.13 六星先锋(执旗手)干员-琴柳\n9.14 宣传策略-空弦+家具\n9.15 轮换池(+中文语音前瞻)\n9.16 停机\n9.17 #罗德岛闲逛部#+新六星EP+EP09·风暴瞭望开启\n9.19 #罗德岛相簿#" + + "\n--------------\n" + + "#明日方舟#\n【新增服饰】\n//殿堂上的游禽 - 星极\n塞壬唱片偶像企划《闪耀阶梯》特供服饰/殿堂上的游禽。星极自费参加了这项企划,尝试着用大众能接受的方式演绎天空之上的故事。\n\n_____________\n谦逊留给观众,骄傲发自歌喉,此夜,唯我璀璨。 " + ) diff --git a/tests/platforms/test_ncm_artist.py b/tests/platforms/test_ncm_artist.py index f5ed1a2..37242e5 100644 --- a/tests/platforms/test_ncm_artist.py +++ b/tests/platforms/test_ncm_artist.py @@ -1,49 +1,54 @@ -from .utils import get_json +import time +import typing + import pytest import respx -import typing -import time from httpx import Response +from .utils import get_json + if typing.TYPE_CHECKING: import sys - sys.path.append('./src/plugins') + + sys.path.append("./src/plugins") import nonebot_bison + @pytest.fixture -def ncm_artist(plugin_module: 'nonebot_bison'): - return plugin_module.platform.platform_manager['ncm-artist'] +def ncm_artist(plugin_module: "nonebot_bison"): + return plugin_module.platform.platform_manager["ncm-artist"] -@pytest.fixture(scope='module') + +@pytest.fixture(scope="module") def ncm_artist_raw(): - return get_json('ncm_siren.json') + return get_json("ncm_siren.json") -@pytest.fixture(scope='module') + +@pytest.fixture(scope="module") def ncm_artist_0(ncm_artist_raw): - return { - **ncm_artist_raw, - 'hotAlbums': ncm_artist_raw['hotAlbums'][1:] - } + return {**ncm_artist_raw, "hotAlbums": ncm_artist_raw["hotAlbums"][1:]} -@pytest.fixture(scope='module') + +@pytest.fixture(scope="module") def ncm_artist_1(ncm_artist_raw: dict): res = ncm_artist_raw.copy() - res['hotAlbums'] = ncm_artist_raw['hotAlbums'][:] - res['hotAlbums'][0]['publishTime'] = int(time.time() * 1000) - return res + res["hotAlbums"] = ncm_artist_raw["hotAlbums"][:] + res["hotAlbums"][0]["publishTime"] = int(time.time() * 1000) + return res + @pytest.mark.asyncio @respx.mock async def test_fetch_new(ncm_artist, ncm_artist_0, ncm_artist_1, dummy_user_subinfo): - ncm_router = respx.get("https://music.163.com/api/artist/albums/32540734") + ncm_router = respx.get("https://music.163.com/api/artist/albums/32540734") ncm_router.mock(return_value=Response(200, json=ncm_artist_0)) - target = '32540734' + target = "32540734" res = await ncm_artist.fetch_new_post(target, [dummy_user_subinfo]) - assert(ncm_router.called) - assert(len(res) == 0) + assert ncm_router.called + assert len(res) == 0 ncm_router.mock(return_value=Response(200, json=ncm_artist_1)) res2 = await ncm_artist.fetch_new_post(target, [dummy_user_subinfo]) post = res2[0][1][0] - assert(post.target_type == 'ncm-artist') - assert(post.text == '新专辑发布:Y1K') - assert(post.url == 'https://music.163.com/#/album?id=131074504') + assert post.target_type == "ncm-artist" + assert post.text == "新专辑发布:Y1K" + assert post.url == "https://music.163.com/#/album?id=131074504" diff --git a/tests/platforms/test_ncm_radio.py b/tests/platforms/test_ncm_radio.py index d248b0f..3ce6cfe 100644 --- a/tests/platforms/test_ncm_radio.py +++ b/tests/platforms/test_ncm_radio.py @@ -1,53 +1,59 @@ +import time +import typing -from .utils import get_json import pytest import respx -import typing -import time from httpx import Response +from .utils import get_json + if typing.TYPE_CHECKING: import sys - sys.path.append('./src/plugins') + + sys.path.append("./src/plugins") import nonebot_bison + @pytest.fixture -def ncm_radio(plugin_module: 'nonebot_bison'): - return plugin_module.platform.platform_manager['ncm-radio'] +def ncm_radio(plugin_module: "nonebot_bison"): + return plugin_module.platform.platform_manager["ncm-radio"] -@pytest.fixture(scope='module') + +@pytest.fixture(scope="module") def ncm_radio_raw(): - return get_json('ncm_radio_ark.json') + return get_json("ncm_radio_ark.json") -@pytest.fixture(scope='module') + +@pytest.fixture(scope="module") def ncm_radio_0(ncm_radio_raw): - return { - **ncm_radio_raw, - 'programs': ncm_radio_raw['programs'][1:] - } + return {**ncm_radio_raw, "programs": ncm_radio_raw["programs"][1:]} -@pytest.fixture(scope='module') + +@pytest.fixture(scope="module") def ncm_radio_1(ncm_radio_raw: dict): res = ncm_radio_raw.copy() - res['programs'] = ncm_radio_raw['programs'][:] - res['programs'][0]['createTime'] = int(time.time() * 1000) - return res + res["programs"] = ncm_radio_raw["programs"][:] + res["programs"][0]["createTime"] = int(time.time() * 1000) + return res + @pytest.mark.asyncio @respx.mock async def test_fetch_new(ncm_radio, ncm_radio_0, ncm_radio_1, dummy_user_subinfo): - ncm_router = respx.post("http://music.163.com/api/dj/program/byradio") + ncm_router = respx.post("http://music.163.com/api/dj/program/byradio") ncm_router.mock(return_value=Response(200, json=ncm_radio_0)) - target = '793745436' + target = "793745436" res = await ncm_radio.fetch_new_post(target, [dummy_user_subinfo]) - assert(ncm_router.called) - assert(len(res) == 0) + assert ncm_router.called + assert len(res) == 0 ncm_router.mock(return_value=Response(200, json=ncm_radio_1)) res2 = await ncm_radio.fetch_new_post(target, [dummy_user_subinfo]) post = res2[0][1][0] print(post) - assert(post.target_type == 'ncm-radio') - assert(post.text == '网易云电台更新:「松烟行动」灰齐山麓') - assert(post.url == 'https://music.163.com/#/program/2494997688') - assert(post.pics == ['http://p1.music.126.net/H5em5xUNIYXcjJhOmeaSqQ==/109951166647436789.jpg']) - assert(post.target_name == '《明日方舟》游戏原声OST') + assert post.target_type == "ncm-radio" + assert post.text == "网易云电台更新:「松烟行动」灰齐山麓" + assert post.url == "https://music.163.com/#/program/2494997688" + assert post.pics == [ + "http://p1.music.126.net/H5em5xUNIYXcjJhOmeaSqQ==/109951166647436789.jpg" + ] + assert post.target_name == "《明日方舟》游戏原声OST" diff --git a/tests/platforms/test_platform.py b/tests/platforms/test_platform.py index f708863..ae55b38 100644 --- a/tests/platforms/test_platform.py +++ b/tests/platforms/test_platform.py @@ -6,42 +6,48 @@ import pytest if typing.TYPE_CHECKING: import sys - sys.path.append('./src/plugins') + + sys.path.append("./src/plugins") import nonebot_bison from nonebot_bison.types import * from nonebot_bison.post import Post from time import time + now = time() passed = now - 3 * 60 * 60 raw_post_list_1 = [ - {'id': 1, 'text': 'p1', 'date': now, 'tags': ['tag1'], 'category': 1} - ] + {"id": 1, "text": "p1", "date": now, "tags": ["tag1"], "category": 1} +] raw_post_list_2 = raw_post_list_1 + [ - {'id': 2, 'text': 'p2', 'date': now, 'tags': ['tag1'], 'category': 1}, - {'id': 3, 'text': 'p3', 'date': now, 'tags': ['tag2'], 'category': 2}, - {'id': 4, 'text': 'p4', 'date': now, 'tags': ['tag2'], 'category': 3} - ] + {"id": 2, "text": "p2", "date": now, "tags": ["tag1"], "category": 1}, + {"id": 3, "text": "p3", "date": now, "tags": ["tag2"], "category": 2}, + {"id": 4, "text": "p4", "date": now, "tags": ["tag2"], "category": 3}, +] + @pytest.fixture -def dummy_user(plugin_module: 'nonebot_bison'): - user = plugin_module.types.User('123', 'group') +def dummy_user(plugin_module: "nonebot_bison"): + user = plugin_module.types.User("123", "group") return user + @pytest.fixture -def user_info_factory(plugin_module: 'nonebot_bison', dummy_user): +def user_info_factory(plugin_module: "nonebot_bison", dummy_user): def _user_info(category_getter, tag_getter): return plugin_module.types.UserSubInfo(dummy_user, category_getter, tag_getter) + return _user_info + @pytest.fixture -def mock_platform_without_cats_tags(plugin_module: 'nonebot_bison'): +def mock_platform_without_cats_tags(plugin_module: "nonebot_bison"): class MockPlatform(plugin_module.platform.platform.NewMessage): - platform_name = 'mock_platform' - name = 'Mock Platform' + platform_name = "mock_platform" + name = "Mock Platform" enabled = True is_common = True schedule_interval = 10 @@ -52,21 +58,26 @@ def mock_platform_without_cats_tags(plugin_module: 'nonebot_bison'): def __init__(self): self.sub_index = 0 super().__init__() - + @staticmethod - async def get_target_name(_: 'Target'): - return 'MockPlatform' + async def get_target_name(_: "Target"): + return "MockPlatform" - def get_id(self, post: 'RawPost') -> Any: - return post['id'] + def get_id(self, post: "RawPost") -> Any: + return post["id"] - def get_date(self, raw_post: 'RawPost') -> float: - return raw_post['date'] + def get_date(self, raw_post: "RawPost") -> float: + return raw_post["date"] - async def parse(self, raw_post: 'RawPost') -> 'Post': - return plugin_module.post.Post('mock_platform', raw_post['text'], 'http://t.tt/' + str(self.get_id(raw_post)), target_name='Mock') + async def parse(self, raw_post: "RawPost") -> "Post": + return plugin_module.post.Post( + "mock_platform", + raw_post["text"], + "http://t.tt/" + str(self.get_id(raw_post)), + target_name="Mock", + ) - async def get_sub_list(self, _: 'Target'): + async def get_sub_list(self, _: "Target"): if self.sub_index == 0: self.sub_index += 1 return raw_post_list_1 @@ -75,45 +86,52 @@ def mock_platform_without_cats_tags(plugin_module: 'nonebot_bison'): return MockPlatform() + @pytest.fixture -def mock_platform(plugin_module: 'nonebot_bison'): +def mock_platform(plugin_module: "nonebot_bison"): class MockPlatform(plugin_module.platform.platform.NewMessage): - platform_name = 'mock_platform' - name = 'Mock Platform' + platform_name = "mock_platform" + name = "Mock Platform" enabled = True is_common = True schedule_interval = 10 enable_tag = True has_target = True categories = { - 1: '转发', - 2: '视频', - } + 1: "转发", + 2: "视频", + } + def __init__(self): self.sub_index = 0 super().__init__() - + @staticmethod - async def get_target_name(_: 'Target'): - return 'MockPlatform' + async def get_target_name(_: "Target"): + return "MockPlatform" - def get_id(self, post: 'RawPost') -> Any: - return post['id'] + def get_id(self, post: "RawPost") -> Any: + return post["id"] - def get_date(self, raw_post: 'RawPost') -> float: - return raw_post['date'] + def get_date(self, raw_post: "RawPost") -> float: + return raw_post["date"] - def get_tags(self, raw_post: 'RawPost') -> list['Tag']: - return raw_post['tags'] + def get_tags(self, raw_post: "RawPost") -> list["Tag"]: + return raw_post["tags"] - def get_category(self, raw_post: 'RawPost') -> 'Category': - return raw_post['category'] + def get_category(self, raw_post: "RawPost") -> "Category": + return raw_post["category"] - async def parse(self, raw_post: 'RawPost') -> 'Post': - return plugin_module.post.Post('mock_platform', raw_post['text'], 'http://t.tt/' + str(self.get_id(raw_post)), target_name='Mock') + async def parse(self, raw_post: "RawPost") -> "Post": + return plugin_module.post.Post( + "mock_platform", + raw_post["text"], + "http://t.tt/" + str(self.get_id(raw_post)), + target_name="Mock", + ) - async def get_sub_list(self, _: 'Target'): + async def get_sub_list(self, _: "Target"): if self.sub_index == 0: self.sub_index += 1 return raw_post_list_1 @@ -122,49 +140,52 @@ def mock_platform(plugin_module: 'nonebot_bison'): return MockPlatform() + @pytest.fixture -def mock_platform_no_target(plugin_module: 'nonebot_bison'): +def mock_platform_no_target(plugin_module: "nonebot_bison"): class MockPlatform(plugin_module.platform.platform.NewMessage): - platform_name = 'mock_platform' - name = 'Mock Platform' + platform_name = "mock_platform" + name = "Mock Platform" enabled = True is_common = True - schedule_type = 'interval' - schedule_kw = {'seconds': 30} + schedule_type = "interval" + schedule_kw = {"seconds": 30} enable_tag = True has_target = False - categories = { - 1: '转发', - 2: '视频', - 3: '不支持' - } + categories = {1: "转发", 2: "视频", 3: "不支持"} + def __init__(self): self.sub_index = 0 super().__init__() - + @staticmethod - async def get_target_name(_: 'Target'): - return 'MockPlatform' + async def get_target_name(_: "Target"): + return "MockPlatform" - def get_id(self, post: 'RawPost') -> Any: - return post['id'] + def get_id(self, post: "RawPost") -> Any: + return post["id"] - def get_date(self, raw_post: 'RawPost') -> float: - return raw_post['date'] + def get_date(self, raw_post: "RawPost") -> float: + return raw_post["date"] - def get_tags(self, raw_post: 'RawPost') -> list['Tag']: - return raw_post['tags'] + def get_tags(self, raw_post: "RawPost") -> list["Tag"]: + return raw_post["tags"] - def get_category(self, raw_post: 'RawPost') -> 'Category': - if raw_post['category'] == 3: + def get_category(self, raw_post: "RawPost") -> "Category": + if raw_post["category"] == 3: raise plugin_module.platform.platform.CategoryNotSupport() - return raw_post['category'] + return raw_post["category"] - async def parse(self, raw_post: 'RawPost') -> 'Post': - return plugin_module.post.Post('mock_platform', raw_post['text'], 'http://t.tt/' + str(self.get_id(raw_post)), target_name='Mock') + async def parse(self, raw_post: "RawPost") -> "Post": + return plugin_module.post.Post( + "mock_platform", + raw_post["text"], + "http://t.tt/" + str(self.get_id(raw_post)), + target_name="Mock", + ) - async def get_sub_list(self, _: 'Target'): + async def get_sub_list(self, _: "Target"): if self.sub_index == 0: self.sub_index += 1 return raw_post_list_1 @@ -173,54 +194,61 @@ def mock_platform_no_target(plugin_module: 'nonebot_bison'): return MockPlatform() + @pytest.fixture -def mock_platform_no_target_2(plugin_module: 'nonebot_bison'): +def mock_platform_no_target_2(plugin_module: "nonebot_bison"): class MockPlatform(plugin_module.platform.platform.NewMessage): - platform_name = 'mock_platform' - name = 'Mock Platform' + platform_name = "mock_platform" + name = "Mock Platform" enabled = True - schedule_type = 'interval' - schedule_kw = {'seconds': 30} + schedule_type = "interval" + schedule_kw = {"seconds": 30} is_common = True enable_tag = True has_target = False categories = { - 4: 'leixing4', - 5: 'leixing5', - } + 4: "leixing4", + 5: "leixing5", + } + def __init__(self): self.sub_index = 0 super().__init__() - + @staticmethod - async def get_target_name(_: 'Target'): - return 'MockPlatform' + async def get_target_name(_: "Target"): + return "MockPlatform" - def get_id(self, post: 'RawPost') -> Any: - return post['id'] + def get_id(self, post: "RawPost") -> Any: + return post["id"] - def get_date(self, raw_post: 'RawPost') -> float: - return raw_post['date'] + def get_date(self, raw_post: "RawPost") -> float: + return raw_post["date"] - def get_tags(self, raw_post: 'RawPost') -> list['Tag']: - return raw_post['tags'] + def get_tags(self, raw_post: "RawPost") -> list["Tag"]: + return raw_post["tags"] - def get_category(self, raw_post: 'RawPost') -> 'Category': - return raw_post['category'] + def get_category(self, raw_post: "RawPost") -> "Category": + return raw_post["category"] - async def parse(self, raw_post: 'RawPost') -> 'Post': - return plugin_module.post.Post('mock_platform_2', raw_post['text'], 'http://t.tt/' + str(self.get_id(raw_post)), target_name='Mock') + async def parse(self, raw_post: "RawPost") -> "Post": + return plugin_module.post.Post( + "mock_platform_2", + raw_post["text"], + "http://t.tt/" + str(self.get_id(raw_post)), + target_name="Mock", + ) - async def get_sub_list(self, _: 'Target'): + async def get_sub_list(self, _: "Target"): list_1 = [ - {'id': 5, 'text': 'p5', 'date': now, 'tags': ['tag1'], 'category': 4} - ] + {"id": 5, "text": "p5", "date": now, "tags": ["tag1"], "category": 4} + ] list_2 = list_1 + [ - {'id': 6, 'text': 'p6', 'date': now, 'tags': ['tag1'], 'category': 4}, - {'id': 7, 'text': 'p7', 'date': now, 'tags': ['tag2'], 'category': 5}, - ] + {"id": 6, "text": "p6", "date": now, "tags": ["tag1"], "category": 4}, + {"id": 7, "text": "p7", "date": now, "tags": ["tag2"], "category": 5}, + ] if self.sub_index == 0: self.sub_index += 1 return list_1 @@ -229,145 +257,190 @@ def mock_platform_no_target_2(plugin_module: 'nonebot_bison'): return MockPlatform() + @pytest.fixture -def mock_status_change(plugin_module: 'nonebot_bison'): +def mock_status_change(plugin_module: "nonebot_bison"): class MockPlatform(plugin_module.platform.platform.StatusChange): - platform_name = 'mock_platform' - name = 'Mock Platform' + platform_name = "mock_platform" + name = "Mock Platform" enabled = True is_common = True enable_tag = False - schedule_type = 'interval' - schedule_kw = {'seconds': 10} + schedule_type = "interval" + schedule_kw = {"seconds": 10} has_target = False categories = { - 1: '转发', - 2: '视频', - } + 1: "转发", + 2: "视频", + } + def __init__(self): self.sub_index = 0 super().__init__() - async def get_status(self, _: 'Target'): + async def get_status(self, _: "Target"): if self.sub_index == 0: self.sub_index += 1 - return {'s': False} + return {"s": False} elif self.sub_index == 1: self.sub_index += 1 - return {'s': True} + return {"s": True} else: - return {'s': False} + return {"s": False} - def compare_status(self, target, old_status, new_status) -> list['RawPost']: - if old_status['s'] == False and new_status['s'] == True: - return [{'text': 'on', 'cat': 1}] - elif old_status['s'] == True and new_status['s'] == False: - return [{'text': 'off', 'cat': 2}] + def compare_status(self, target, old_status, new_status) -> list["RawPost"]: + if old_status["s"] == False and new_status["s"] == True: + return [{"text": "on", "cat": 1}] + elif old_status["s"] == True and new_status["s"] == False: + return [{"text": "off", "cat": 2}] return [] - async def parse(self, raw_post) -> 'Post': - return plugin_module.post.Post('mock_status', raw_post['text'], '') + async def parse(self, raw_post) -> "Post": + return plugin_module.post.Post("mock_status", raw_post["text"], "") def get_category(self, raw_post): - return raw_post['cat'] + return raw_post["cat"] return MockPlatform() @pytest.mark.asyncio -async def test_new_message_target_without_cats_tags(mock_platform_without_cats_tags, user_info_factory): - res1 = await mock_platform_without_cats_tags.fetch_new_post('dummy', [user_info_factory(lambda _: [1,2], lambda _: [])]) - assert(len(res1) == 0) - res2 = await mock_platform_without_cats_tags.fetch_new_post('dummy', [ - user_info_factory(lambda _: [], lambda _: []), - ]) - assert(len(res2) == 1) +async def test_new_message_target_without_cats_tags( + mock_platform_without_cats_tags, user_info_factory +): + res1 = await mock_platform_without_cats_tags.fetch_new_post( + "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])] + ) + assert len(res1) == 0 + res2 = await mock_platform_without_cats_tags.fetch_new_post( + "dummy", + [ + user_info_factory(lambda _: [], lambda _: []), + ], + ) + assert len(res2) == 1 posts_1 = res2[0][1] - assert(len(posts_1) == 3) + assert len(posts_1) == 3 id_set_1 = set(map(lambda x: x.text, posts_1)) - assert('p2' in id_set_1 and 'p3' in id_set_1 and 'p4' in id_set_1) + assert "p2" in id_set_1 and "p3" in id_set_1 and "p4" in id_set_1 + @pytest.mark.asyncio async def test_new_message_target(mock_platform, user_info_factory): - res1 = await mock_platform.fetch_new_post('dummy', [user_info_factory(lambda _: [1,2], lambda _: [])]) - assert(len(res1) == 0) - res2 = await mock_platform.fetch_new_post('dummy', [ - user_info_factory(lambda _: [1,2], lambda _: []), - user_info_factory(lambda _: [1], lambda _: []), - user_info_factory(lambda _: [1,2], lambda _: ['tag1']) - ]) - assert(len(res2) == 3) + res1 = await mock_platform.fetch_new_post( + "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])] + ) + assert len(res1) == 0 + res2 = await mock_platform.fetch_new_post( + "dummy", + [ + user_info_factory(lambda _: [1, 2], lambda _: []), + user_info_factory(lambda _: [1], lambda _: []), + user_info_factory(lambda _: [1, 2], lambda _: ["tag1"]), + ], + ) + assert len(res2) == 3 posts_1 = res2[0][1] posts_2 = res2[1][1] posts_3 = res2[2][1] - assert(len(posts_1) == 2) - assert(len(posts_2) == 1) - assert(len(posts_3) == 1) + assert len(posts_1) == 2 + assert len(posts_2) == 1 + assert len(posts_3) == 1 id_set_1 = set(map(lambda x: x.text, posts_1)) id_set_2 = set(map(lambda x: x.text, posts_2)) id_set_3 = set(map(lambda x: x.text, posts_3)) - assert('p2' in id_set_1 and 'p3' in id_set_1) - assert('p2' in id_set_2) - assert('p2' in id_set_3) + assert "p2" in id_set_1 and "p3" in id_set_1 + assert "p2" in id_set_2 + assert "p2" in id_set_3 + @pytest.mark.asyncio async def test_new_message_no_target(mock_platform_no_target, user_info_factory): - res1 = await mock_platform_no_target.fetch_new_post('dummy', [user_info_factory(lambda _: [1,2], lambda _: [])]) - assert(len(res1) == 0) - res2 = await mock_platform_no_target.fetch_new_post('dummy', [ - user_info_factory(lambda _: [1,2], lambda _: []), - user_info_factory(lambda _: [1], lambda _: []), - user_info_factory(lambda _: [1,2], lambda _: ['tag1']) - ]) - assert(len(res2) == 3) + res1 = await mock_platform_no_target.fetch_new_post( + "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])] + ) + assert len(res1) == 0 + res2 = await mock_platform_no_target.fetch_new_post( + "dummy", + [ + user_info_factory(lambda _: [1, 2], lambda _: []), + user_info_factory(lambda _: [1], lambda _: []), + user_info_factory(lambda _: [1, 2], lambda _: ["tag1"]), + ], + ) + assert len(res2) == 3 posts_1 = res2[0][1] posts_2 = res2[1][1] posts_3 = res2[2][1] - assert(len(posts_1) == 2) - assert(len(posts_2) == 1) - assert(len(posts_3) == 1) + assert len(posts_1) == 2 + assert len(posts_2) == 1 + assert len(posts_3) == 1 id_set_1 = set(map(lambda x: x.text, posts_1)) id_set_2 = set(map(lambda x: x.text, posts_2)) id_set_3 = set(map(lambda x: x.text, posts_3)) - assert('p2' in id_set_1 and 'p3' in id_set_1) - assert('p2' in id_set_2) - assert('p2' in id_set_3) - res3 = await mock_platform_no_target.fetch_new_post('dummy', [user_info_factory(lambda _: [1,2], lambda _: [])]) - assert(len(res3) == 0) + assert "p2" in id_set_1 and "p3" in id_set_1 + assert "p2" in id_set_2 + assert "p2" in id_set_3 + res3 = await mock_platform_no_target.fetch_new_post( + "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])] + ) + assert len(res3) == 0 + @pytest.mark.asyncio async def test_status_change(mock_status_change, user_info_factory): - res1 = await mock_status_change.fetch_new_post('dummy', [user_info_factory(lambda _: [1,2], lambda _: [])]) - assert(len(res1) == 0) - res2 = await mock_status_change.fetch_new_post('dummy', [ - user_info_factory(lambda _: [1,2], lambda _:[]) - ]) - assert(len(res2) == 1) + res1 = await mock_status_change.fetch_new_post( + "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])] + ) + assert len(res1) == 0 + res2 = await mock_status_change.fetch_new_post( + "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])] + ) + assert len(res2) == 1 posts = res2[0][1] - assert(len(posts) == 1) - assert(posts[0].text == 'on') - res3 = await mock_status_change.fetch_new_post('dummy', [ - user_info_factory(lambda _: [1,2], lambda _: []), - user_info_factory(lambda _: [1], lambda _: []), - ]) - assert(len(res3) == 2) - assert(len(res3[0][1]) == 1) - assert(res3[0][1][0].text == 'off') - assert(len(res3[1][1]) == 0) - res4 = await mock_status_change.fetch_new_post('dummy', [user_info_factory(lambda _: [1,2], lambda _: [])]) - assert(len(res4) == 0) + assert len(posts) == 1 + assert posts[0].text == "on" + res3 = await mock_status_change.fetch_new_post( + "dummy", + [ + user_info_factory(lambda _: [1, 2], lambda _: []), + user_info_factory(lambda _: [1], lambda _: []), + ], + ) + assert len(res3) == 2 + assert len(res3[0][1]) == 1 + assert res3[0][1][0].text == "off" + assert len(res3[1][1]) == 0 + res4 = await mock_status_change.fetch_new_post( + "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])] + ) + assert len(res4) == 0 + @pytest.mark.asyncio -async def test_group(plugin_module: 'nonebot_bison', mock_platform_no_target, mock_platform_no_target_2, user_info_factory): - group_platform = plugin_module.platform.platform.NoTargetGroup([mock_platform_no_target, mock_platform_no_target_2]) - res1 = await group_platform.fetch_new_post('dummy', [user_info_factory(lambda _: [1,4], lambda _: [])]) - assert(len(res1) == 0) - res2 = await group_platform.fetch_new_post('dummy', [user_info_factory(lambda _: [1,4], lambda _: [])]) - assert(len(res2) == 1) +async def test_group( + plugin_module: "nonebot_bison", + mock_platform_no_target, + mock_platform_no_target_2, + user_info_factory, +): + group_platform = plugin_module.platform.platform.NoTargetGroup( + [mock_platform_no_target, mock_platform_no_target_2] + ) + res1 = await group_platform.fetch_new_post( + "dummy", [user_info_factory(lambda _: [1, 4], lambda _: [])] + ) + assert len(res1) == 0 + res2 = await group_platform.fetch_new_post( + "dummy", [user_info_factory(lambda _: [1, 4], lambda _: [])] + ) + assert len(res2) == 1 posts = res2[0][1] - assert(len(posts) == 2) + assert len(posts) == 2 id_set_2 = set(map(lambda x: x.text, posts)) - assert('p2' in id_set_2 and 'p6' in id_set_2) - res3 = await group_platform.fetch_new_post('dummy', [user_info_factory(lambda _: [1,4], lambda _: [])]) - assert(len(res3) == 0) + assert "p2" in id_set_2 and "p6" in id_set_2 + res3 = await group_platform.fetch_new_post( + "dummy", [user_info_factory(lambda _: [1, 4], lambda _: [])] + ) + assert len(res3) == 0 diff --git a/tests/platforms/test_weibo.py b/tests/platforms/test_weibo.py index 5fa9c00..68ee873 100644 --- a/tests/platforms/test_weibo.py +++ b/tests/platforms/test_weibo.py @@ -1,110 +1,131 @@ -import pytest import typing -import respx from datetime import datetime -from pytz import timezone -from httpx import Response + import feedparser +import pytest +import respx +from httpx import Response +from pytz import timezone if typing.TYPE_CHECKING: import sys - sys.path.append('./src/plugins') + + sys.path.append("./src/plugins") import nonebot_bison -from .utils import get_json, get_file +from .utils import get_file, get_json + @pytest.fixture -def weibo(plugin_module: 'nonebot_bison'): - return plugin_module.platform.platform_manager['weibo'] +def weibo(plugin_module: "nonebot_bison"): + return plugin_module.platform.platform_manager["weibo"] -@pytest.fixture(scope='module') + +@pytest.fixture(scope="module") def weibo_ak_list_1(): - return get_json('weibo_ak_list_1.json') + return get_json("weibo_ak_list_1.json") + @pytest.mark.asyncio async def test_get_name(weibo): - name = await weibo.get_target_name('6279793937') - assert(name == "明日方舟Arknights") + name = await weibo.get_target_name("6279793937") + assert name == "明日方舟Arknights" + @pytest.mark.asyncio @respx.mock async def test_fetch_new(weibo, dummy_user_subinfo): - ak_list_router = respx.get("https://m.weibo.cn/api/container/getIndex?containerid=1076036279793937") + ak_list_router = respx.get( + "https://m.weibo.cn/api/container/getIndex?containerid=1076036279793937" + ) detail_router = respx.get("https://m.weibo.cn/detail/4649031014551911") - ak_list_router.mock(return_value=Response(200, json=get_json('weibo_ak_list_0.json'))) - detail_router.mock(return_value=Response(200, text=get_file('weibo_detail_4649031014551911'))) - target = '6279793937' + ak_list_router.mock( + return_value=Response(200, json=get_json("weibo_ak_list_0.json")) + ) + detail_router.mock( + return_value=Response(200, text=get_file("weibo_detail_4649031014551911")) + ) + target = "6279793937" res = await weibo.fetch_new_post(target, [dummy_user_subinfo]) - assert(ak_list_router.called) - assert(len(res) == 0) - assert(not detail_router.called) - mock_data = get_json('weibo_ak_list_1.json') + assert ak_list_router.called + assert len(res) == 0 + assert not detail_router.called + mock_data = get_json("weibo_ak_list_1.json") ak_list_router.mock(return_value=Response(200, json=mock_data)) # import ipdb; ipdb.set_trace() res2 = await weibo.fetch_new_post(target, [dummy_user_subinfo]) - assert(len(res2) == 0) - mock_data['data']['cards'][1]['mblog']['created_at'] = \ - datetime.now(timezone('Asia/Shanghai')).strftime('%a %b %d %H:%M:%S %z %Y') + assert len(res2) == 0 + mock_data["data"]["cards"][1]["mblog"]["created_at"] = datetime.now( + timezone("Asia/Shanghai") + ).strftime("%a %b %d %H:%M:%S %z %Y") ak_list_router.mock(return_value=Response(200, json=mock_data)) res3 = await weibo.fetch_new_post(target, [dummy_user_subinfo]) - assert(len(res3[0][1]) == 1) - assert(not detail_router.called) + assert len(res3[0][1]) == 1 + assert not detail_router.called post = res3[0][1][0] - assert(post.target_type == 'weibo') - assert(post.text == '#明日方舟#\nSideStory「沃伦姆德的薄暮」复刻现已开启! ') - assert(post.url == 'https://weibo.com/6279793937/KkBtUx2dv') - assert(post.target_name == '明日方舟Arknights') - assert(len(post.pics) == 1) + assert post.target_type == "weibo" + assert post.text == "#明日方舟#\nSideStory「沃伦姆德的薄暮」复刻现已开启! " + assert post.url == "https://weibo.com/6279793937/KkBtUx2dv" + assert post.target_name == "明日方舟Arknights" + assert len(post.pics) == 1 + @pytest.mark.asyncio async def test_classification(weibo): - mock_data = get_json('weibo_ak_list_1.json') - tuwen = mock_data['data']['cards'][1] - retweet = mock_data['data']['cards'][3] - video = mock_data['data']['cards'][0] - mock_data_ys = get_json('weibo_ys_list_0.json') - text = mock_data_ys['data']['cards'][2] - assert(weibo.get_category(retweet) == 1) - assert(weibo.get_category(video) == 2) - assert(weibo.get_category(tuwen) == 3) - assert(weibo.get_category(text) == 4) + mock_data = get_json("weibo_ak_list_1.json") + tuwen = mock_data["data"]["cards"][1] + retweet = mock_data["data"]["cards"][3] + video = mock_data["data"]["cards"][0] + mock_data_ys = get_json("weibo_ys_list_0.json") + text = mock_data_ys["data"]["cards"][2] + assert weibo.get_category(retweet) == 1 + assert weibo.get_category(video) == 2 + assert weibo.get_category(tuwen) == 3 + assert weibo.get_category(text) == 4 + @pytest.mark.asyncio @respx.mock async def test_parse_long(weibo): detail_router = respx.get("https://m.weibo.cn/detail/4645748019299849") - detail_router.mock(return_value=Response(200, text=get_file('weibo_detail_4645748019299849'))) - raw_post = get_json('weibo_ak_list_1.json')['data']['cards'][0] + detail_router.mock( + return_value=Response(200, text=get_file("weibo_detail_4645748019299849")) + ) + raw_post = get_json("weibo_ak_list_1.json")["data"]["cards"][0] post = await weibo.parse(raw_post) - assert(not '全文' in post.text) - assert(detail_router.called) + assert not "全文" in post.text + assert detail_router.called + def test_tag(weibo, weibo_ak_list_1): - raw_post = weibo_ak_list_1['data']['cards'][0] - assert(weibo.get_tags(raw_post) == ['明日方舟', '音律联觉']) + raw_post = weibo_ak_list_1["data"]["cards"][0] + assert weibo.get_tags(raw_post) == ["明日方舟", "音律联觉"] + @pytest.mark.asyncio @pytest.mark.compare async def test_rsshub_compare(weibo): - target = '6279793937' + target = "6279793937" raw_posts = filter(weibo.filter_platform_custom, await weibo.get_sub_list(target)) posts = [] for raw_post in raw_posts: posts.append(await weibo.parse(raw_post)) url_set = set(map(lambda x: x.url, posts)) - feedres = feedparser.parse('https://rsshub.app/weibo/user/6279793937') + feedres = feedparser.parse("https://rsshub.app/weibo/user/6279793937") for entry in feedres.entries[:5]: # print(entry) - assert(entry.link in url_set) + assert entry.link in url_set + test_post = { - "mblog": { - "text": "#刚出生的小羊驼长啥样#
小羊驼三三来也[好喜欢]
小羊驼三三 ", - "bid": "KnssqeqKK" - } + "mblog": { + "text": '#刚出生的小羊驼长啥样#
小羊驼三三来也[好喜欢]
小羊驼三三 ', + "bid": "KnssqeqKK", + } } + + def test_chaohua_tag(weibo): tags = weibo.get_tags(test_post) - assert('刚出生的小羊驼长啥样' in tags) - assert('小羊驼三三超话' in tags) - + assert "刚出生的小羊驼长啥样" in tags + assert "小羊驼三三超话" in tags diff --git a/tests/platforms/utils.py b/tests/platforms/utils.py index da6d0e9..b80374c 100644 --- a/tests/platforms/utils.py +++ b/tests/platforms/utils.py @@ -1,12 +1,16 @@ -from pathlib import Path import json +from pathlib import Path + path = Path(__file__).parent + + def get_json(file_name: str): - with open(path / file_name, 'r') as f: - file_text = f.read() + with open(path / file_name, "r") as f: + file_text = f.read() return json.loads(file_text) -def get_file(file_name:str): - with open(path / file_name, 'r') as f: - file_text = f.read() + +def get_file(file_name: str): + with open(path / file_name, "r") as f: + file_text = f.read() return file_text diff --git a/tests/test_config_manager.py b/tests/test_config_manager.py index 5ae7fe2..bd7a638 100644 --- a/tests/test_config_manager.py +++ b/tests/test_config_manager.py @@ -1,38 +1,47 @@ -import pytest import typing +import pytest + if typing.TYPE_CHECKING: import sys - sys.path.append('./src/plugins') + + sys.path.append("./src/plugins") import nonebot_bison + @pytest.fixture def config(plugin_module): plugin_module.config.start_up() return plugin_module.config.Config() -def test_create_and_get(config: 'nonebot_bison.config.Config', plugin_module: 'nonebot_bison'): + +def test_create_and_get( + config: "nonebot_bison.config.Config", plugin_module: "nonebot_bison" +): config.add_subscribe( - user='123', - user_type='group', - target='weibo_id', - target_name='weibo_name', - target_type='weibo', - cats=[], - tags=[]) - confs = config.list_subscribe('123', 'group') - assert(len(confs) == 1) - assert(config.target_user_cache['weibo']['weibo_id'] == \ - [plugin_module.types.User('123', 'group')]) - assert(confs[0]['cats'] == []) + user="123", + user_type="group", + target="weibo_id", + target_name="weibo_name", + target_type="weibo", + cats=[], + tags=[], + ) + confs = config.list_subscribe("123", "group") + assert len(confs) == 1 + assert config.target_user_cache["weibo"]["weibo_id"] == [ + plugin_module.types.User("123", "group") + ] + assert confs[0]["cats"] == [] config.update_subscribe( - user='123', - user_type='group', - target='weibo_id', - target_name='weibo_name', - target_type='weibo', - cats=['1'], - tags=[]) - confs = config.list_subscribe('123', 'group') - assert(len(confs) == 1) - assert(confs[0]['cats'] == ['1']) + user="123", + user_type="group", + target="weibo_id", + target_name="weibo_name", + target_type="weibo", + cats=["1"], + tags=[], + ) + confs = config.list_subscribe("123", "group") + assert len(confs) == 1 + assert confs[0]["cats"] == ["1"] diff --git a/tests/test_merge_pic.py b/tests/test_merge_pic.py index 6dcd2c2..2008f1f 100644 --- a/tests/test_merge_pic.py +++ b/tests/test_merge_pic.py @@ -1,75 +1,87 @@ -import pytest import typing +import pytest + if typing.TYPE_CHECKING: import sys - sys.path.append('./src/plugins') + + sys.path.append("./src/plugins") import nonebot_bison merge_source_9 = [ - 'https://wx1.sinaimg.cn/large/0071VPLMgy1gq0vib7zooj30dx0dxmz5.jpg', - "https://wx4.sinaimg.cn/large/0071VPLMgy1gq0vib5oqjj30dw0dxjt2.jpg", - "https://wx2.sinaimg.cn/large/0071VPLMgy1gq0vib8bjmj30dv0dxgn7.jpg", - "https://wx1.sinaimg.cn/large/0071VPLMgy1gq0vib6pn1j30dx0dw75v.jpg", - "https://wx4.sinaimg.cn/large/0071VPLMgy1gq0vib925mj30dw0dwabb.jpg", - "https://wx2.sinaimg.cn/large/0071VPLMgy1gq0vib7ujuj30dv0dwtap.jpg", - "https://wx1.sinaimg.cn/large/0071VPLMgy1gq0vibaexnj30dx0dvq49.jpg", - "https://wx1.sinaimg.cn/large/0071VPLMgy1gq0vibehw4j30dw0dv74u.jpg", - "https://wx1.sinaimg.cn/large/0071VPLMgy1gq0vibfb5fj30dv0dvtac.jpg", - "https://wx3.sinaimg.cn/large/0071VPLMgy1gq0viexkjxj30rs3pcx6p.jpg", - "https://wx2.sinaimg.cn/large/0071VPLMgy1gq0vif6qrpj30rs4mou10.jpg", - "https://wx4.sinaimg.cn/large/0071VPLMgy1gq0vifc826j30rs4a64qs.jpg", - "https://wx1.sinaimg.cn/large/0071VPLMgy1gq0vify21lj30rsbj71ld.jpg", - ] + "https://wx1.sinaimg.cn/large/0071VPLMgy1gq0vib7zooj30dx0dxmz5.jpg", + "https://wx4.sinaimg.cn/large/0071VPLMgy1gq0vib5oqjj30dw0dxjt2.jpg", + "https://wx2.sinaimg.cn/large/0071VPLMgy1gq0vib8bjmj30dv0dxgn7.jpg", + "https://wx1.sinaimg.cn/large/0071VPLMgy1gq0vib6pn1j30dx0dw75v.jpg", + "https://wx4.sinaimg.cn/large/0071VPLMgy1gq0vib925mj30dw0dwabb.jpg", + "https://wx2.sinaimg.cn/large/0071VPLMgy1gq0vib7ujuj30dv0dwtap.jpg", + "https://wx1.sinaimg.cn/large/0071VPLMgy1gq0vibaexnj30dx0dvq49.jpg", + "https://wx1.sinaimg.cn/large/0071VPLMgy1gq0vibehw4j30dw0dv74u.jpg", + "https://wx1.sinaimg.cn/large/0071VPLMgy1gq0vibfb5fj30dv0dvtac.jpg", + "https://wx3.sinaimg.cn/large/0071VPLMgy1gq0viexkjxj30rs3pcx6p.jpg", + "https://wx2.sinaimg.cn/large/0071VPLMgy1gq0vif6qrpj30rs4mou10.jpg", + "https://wx4.sinaimg.cn/large/0071VPLMgy1gq0vifc826j30rs4a64qs.jpg", + "https://wx1.sinaimg.cn/large/0071VPLMgy1gq0vify21lj30rsbj71ld.jpg", +] merge_source_9_2 = [ - 'https://wx2.sinaimg.cn/large/0071VPLMgy1gxo0eyycd7j30dw0dd3zk.jpg', - 'https://wx1.sinaimg.cn/large/0071VPLMgy1gxo0eyx6mhj30dw0ddjs8.jpg', - 'https://wx4.sinaimg.cn/large/0071VPLMgy1gxo0eyxf2bj30dw0dddh4.jpg', - 'https://wx3.sinaimg.cn/large/0071VPLMgy1gxo0ez1h5zj30dw0efwfs.jpg', - 'https://wx3.sinaimg.cn/large/0071VPLMgy1gxo0eyyku4j30dw0ef3zm.jpg', - 'https://wx1.sinaimg.cn/large/0071VPLMgy1gxo0ez0bjhj30dw0efabs.jpg', - 'https://wx4.sinaimg.cn/large/0071VPLMgy1gxo0ezdcafj30dw0dwacb.jpg', - 'https://wx1.sinaimg.cn/large/0071VPLMgy1gxo0ezg2g3j30dw0dwq51.jpg', - 'https://wx3.sinaimg.cn/large/0071VPLMgy1gxo0ez5oloj30dw0dw0uf.jpg', - 'https://wx4.sinaimg.cn/large/0071VPLMgy1gxo0fnk6stj30rs44ne81.jpg', - 'https://wx2.sinaimg.cn/large/0071VPLMgy1gxo0fohgcoj30rs3wpe81.jpg', - 'https://wx3.sinaimg.cn/large/0071VPLMgy1gxo0fpr6chj30rs3m1b29.jpg' - ] + "https://wx2.sinaimg.cn/large/0071VPLMgy1gxo0eyycd7j30dw0dd3zk.jpg", + "https://wx1.sinaimg.cn/large/0071VPLMgy1gxo0eyx6mhj30dw0ddjs8.jpg", + "https://wx4.sinaimg.cn/large/0071VPLMgy1gxo0eyxf2bj30dw0dddh4.jpg", + "https://wx3.sinaimg.cn/large/0071VPLMgy1gxo0ez1h5zj30dw0efwfs.jpg", + "https://wx3.sinaimg.cn/large/0071VPLMgy1gxo0eyyku4j30dw0ef3zm.jpg", + "https://wx1.sinaimg.cn/large/0071VPLMgy1gxo0ez0bjhj30dw0efabs.jpg", + "https://wx4.sinaimg.cn/large/0071VPLMgy1gxo0ezdcafj30dw0dwacb.jpg", + "https://wx1.sinaimg.cn/large/0071VPLMgy1gxo0ezg2g3j30dw0dwq51.jpg", + "https://wx3.sinaimg.cn/large/0071VPLMgy1gxo0ez5oloj30dw0dw0uf.jpg", + "https://wx4.sinaimg.cn/large/0071VPLMgy1gxo0fnk6stj30rs44ne81.jpg", + "https://wx2.sinaimg.cn/large/0071VPLMgy1gxo0fohgcoj30rs3wpe81.jpg", + "https://wx3.sinaimg.cn/large/0071VPLMgy1gxo0fpr6chj30rs3m1b29.jpg", +] + @pytest.mark.asyncio -async def test_9_merge(plugin_module: 'nonebot_bison'): - post = plugin_module.post.Post('', '', '', pics=merge_source_9) - await post._pic_merge() +async def test_9_merge(plugin_module: "nonebot_bison"): + post = plugin_module.post.Post("", "", "", pics=merge_source_9) + await post._pic_merge() assert len(post.pics) == 5 await post.generate_messages() + @pytest.mark.asyncio -async def test_9_merge_2(plugin_module: 'nonebot_bison'): - post = plugin_module.post.Post('', '', '', pics=merge_source_9_2) +async def test_9_merge_2(plugin_module: "nonebot_bison"): + post = plugin_module.post.Post("", "", "", pics=merge_source_9_2) await post._pic_merge() assert len(post.pics) == 4 await post.generate_messages() + @pytest.mark.asyncio async def test_6_merge(plugin_module): - post = plugin_module.post.Post('', '', '', pics=merge_source_9[0:6]+merge_source_9[9:]) - await post._pic_merge() + post = plugin_module.post.Post( + "", "", "", pics=merge_source_9[0:6] + merge_source_9[9:] + ) + await post._pic_merge() assert len(post.pics) == 5 + @pytest.mark.asyncio async def test_3_merge(plugin_module): - post = plugin_module.post.Post('', '', '', pics=merge_source_9[0:3]+merge_source_9[9:]) - await post._pic_merge() + post = plugin_module.post.Post( + "", "", "", pics=merge_source_9[0:3] + merge_source_9[9:] + ) + await post._pic_merge() assert len(post.pics) == 5 + @pytest.mark.asyncio async def test_6_merge_only(plugin_module): - post = plugin_module.post.Post('', '', '', pics=merge_source_9[0:6]) - await post._pic_merge() + post = plugin_module.post.Post("", "", "", pics=merge_source_9[0:6]) + await post._pic_merge() assert len(post.pics) == 1 + @pytest.mark.asyncio async def test_3_merge_only(plugin_module): - post = plugin_module.post.Post('', '', '', pics=merge_source_9[0:3]) - await post._pic_merge() + post = plugin_module.post.Post("", "", "", pics=merge_source_9[0:3]) + await post._pic_merge() assert len(post.pics) == 1 diff --git a/tests/test_render.py b/tests/test_render.py index 0d180ff..edeb478 100644 --- a/tests/test_render.py +++ b/tests/test_render.py @@ -1,20 +1,25 @@ -import pytest import typing +import pytest + if typing.TYPE_CHECKING: import sys - sys.path.append('./src/plugins') + + sys.path.append("./src/plugins") import nonebot_bison + @pytest.mark.asyncio @pytest.mark.render -async def test_render(plugin_module: 'nonebot_bison'): +async def test_render(plugin_module: "nonebot_bison"): render = plugin_module.utils.Render() - res = await render.text_to_pic('''a\nbbbbbbbbbbbbbbbbbbbbbb\ncd + res = await render.text_to_pic( + """a\nbbbbbbbbbbbbbbbbbbbbbb\ncd

中文

VuePress 由两部分组成:第一部分是一个极简静态网站生成器 (opens new window),它包含由 Vue 驱动的主题系统和插件 API,另一个部分是为书写技术文档而优化的默认主题,它的诞生初衷是为了支持 Vue 及其子项目的文档需求。 每一个由 VuePress 生成的页面都带有预渲染好的 HTML,也因此具有非常好的加载性能和搜索引擎优化(SEO)。同时,一旦页面被加载,Vue 将接管这些静态内容,并将其转换成一个完整的单页应用(SPA),其他的页面则会只在用户浏览到的时候才按需加载。 -''') +""" + )