diff --git a/nonebot_bison/__init__.py b/nonebot_bison/__init__.py index 3b7db3f..46dde3f 100644 --- a/nonebot_bison/__init__.py +++ b/nonebot_bison/__init__.py @@ -6,25 +6,18 @@ require("nonebot_plugin_saa") import nonebot_plugin_saa -from . import ( - admin_page, - bootstrap, - config, - platform, - post, - scheduler, - send, - sub_manager, - types, - utils, -) from .plugin_config import PlugConfig, plugin_config +from . import post, send, types, utils, config, platform, bootstrap, scheduler, admin_page, sub_manager __help__version__ = "0.7.3" nonebot_plugin_saa.enable_auto_select_bot() __help__plugin__name__ = "nonebot_bison" -__usage__ = f"本bot可以提供b站、微博等社交媒体的消息订阅,详情请查看本bot文档,或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,发送“查询订阅”或“删除订阅”管理订阅" +__usage__ = ( + "本bot可以提供b站、微博等社交媒体的消息订阅,详情请查看本bot文档," + f"或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号," + f"发送“查询订阅”或“删除订阅”管理订阅" +) __supported_adapters__ = nonebot_plugin_saa.__plugin_meta__.supported_adapters @@ -41,6 +34,7 @@ __plugin_meta__ = PluginMetadata( __all__ = [ "admin_page", + "bootstrap", "config", "sub_manager", "post", diff --git a/nonebot_bison/admin_page/__init__.py b/nonebot_bison/admin_page/__init__.py index b6ada13..661c57e 100644 --- a/nonebot_bison/admin_page/__init__.py +++ b/nonebot_bison/admin_page/__init__.py @@ -1,17 +1,16 @@ import os from pathlib import Path -from typing import Union -from nonebot import get_driver, on_command -from nonebot.adapters.onebot.v11 import Bot -from nonebot.adapters.onebot.v11.event import PrivateMessageEvent -from nonebot.drivers.fastapi import Driver from nonebot.log import logger from nonebot.rule import to_me from nonebot.typing import T_State +from nonebot import get_driver, on_command +from nonebot.drivers.fastapi import Driver +from nonebot.adapters.onebot.v11 import Bot +from nonebot.adapters.onebot.v11.event import PrivateMessageEvent -from ..plugin_config import plugin_config from .api import router as api_router +from ..plugin_config import plugin_config from .token_manager import token_manager as tm STATIC_PATH = (Path(__file__).parent / "dist").resolve() @@ -28,11 +27,9 @@ def init_fastapi(): class SinglePageApplication(StaticFiles): def __init__(self, directory: os.PathLike, index="index.html"): self.index = index - super().__init__( - directory=directory, packages=None, html=True, check_dir=True - ) + super().__init__(directory=directory, packages=None, html=True, check_dir=True) - def lookup_path(self, path: str) -> tuple[str, Union[os.stat_result, None]]: + def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]: full_path, stat_res = super().lookup_path(path) if stat_res is None: return super().lookup_path(self.index) @@ -45,9 +42,7 @@ def init_fastapi(): description="nonebot-bison webui and api", ) nonebot_app.include_router(api_router) - nonebot_app.mount( - "/", SinglePageApplication(directory=static_path), name="bison-frontend" - ) + nonebot_app.mount("/", SinglePageApplication(directory=static_path), name="bison-frontend") app = driver.server_app app.mount("/bison", nonebot_app, "nonebot-bison") @@ -63,10 +58,9 @@ def init_fastapi(): if host in ["0.0.0.0", "127.0.0.1"]: host = "localhost" logger.opt(colors=True).info( - f"Nonebot Bison frontend will be running at: " - f"http://{host}:{port}/bison" + f"Nonebot Bison frontend will be running at: " f"http://{host}:{port}/bison" ) - logger.opt(colors=True).info(f"该页面不能被直接访问,请私聊bot 后台管理 以获取可访问地址") + logger.opt(colors=True).info("该页面不能被直接访问,请私聊bot 后台管理 以获取可访问地址") def register_get_token_handler(): @@ -93,6 +87,4 @@ if (STATIC_PATH / "index.html").exists(): else: logger.warning("your driver is not fastapi, webui feature will be disabled") else: - logger.warning( - "Frontend file not found, please compile it or use docker or pypi version" - ) + logger.warning("Frontend file not found, please compile it or use docker or pypi version") diff --git a/nonebot_bison/admin_page/api.py b/nonebot_bison/admin_page/api.py index 5a4fb81..4c299af 100644 --- a/nonebot_bison/admin_page/api.py +++ b/nonebot_bison/admin_page/api.py @@ -1,35 +1,30 @@ import nonebot from fastapi import status -from fastapi.exceptions import HTTPException -from fastapi.param_functions import Depends from fastapi.routing import APIRouter -from fastapi.security.oauth2 import OAuth2PasswordBearer +from fastapi.param_functions import Depends +from fastapi.exceptions import HTTPException from nonebot_plugin_saa import TargetQQGroup +from fastapi.security.oauth2 import OAuth2PasswordBearer from nonebot_plugin_saa.utils.auto_select_bot import get_bot -from ..apis import check_sub_target -from ..config import ( - NoSuchSubscribeException, - NoSuchTargetException, - NoSuchUserException, - config, -) -from ..config.db_config import SubscribeDupException -from ..platform import platform_manager -from ..types import Target as T_Target from ..types import WeightConfig -from ..utils.get_bot import get_groups +from ..apis import check_sub_target from .jwt import load_jwt, pack_jwt +from ..types import Target as T_Target +from ..utils.get_bot import get_groups +from ..platform import platform_manager from .token_manager import token_manager +from ..config.db_config import SubscribeDupException +from ..config import NoSuchUserException, NoSuchTargetException, NoSuchSubscribeException, config from .types import ( - AddSubscribeReq, + TokenResp, GlobalConf, - PlatformConfig, StatusResp, + SubscribeResp, + PlatformConfig, + AddSubscribeReq, SubscribeConfig, SubscribeGroupDetail, - SubscribeResp, - TokenResp, ) router = APIRouter(prefix="/api", tags=["api"]) @@ -44,9 +39,7 @@ async def get_jwt_obj(token: str = Depends(oath_scheme)): return obj -async def check_group_permission( - groupNumber: int, token_obj: dict = Depends(get_jwt_obj) -): +async def check_group_permission(groupNumber: int, token_obj: dict = Depends(get_jwt_obj)): groups = token_obj["groups"] for group in groups: if int(groupNumber) == group["id"]: @@ -95,15 +88,13 @@ async def auth(token: str) -> TokenResp: jwt_obj = { "id": qq, "type": "admin", - "groups": list( - map( - lambda info: { - "id": info["group_id"], - "name": info["group_name"], - }, - await get_groups(), - ) - ), + "groups": [ + { + "id": info["group_id"], + "name": info["group_name"], + } + for info in await get_groups() + ], } ret_obj = TokenResp( type="admin", @@ -134,18 +125,16 @@ async def get_subs_info(jwt_obj: dict = Depends(get_jwt_obj)) -> SubscribeResp: for group in groups: group_id = group["id"] raw_subs = await config.list_subscribe(TargetQQGroup(group_id=group_id)) - subs = list( - map( - lambda sub: SubscribeConfig( - platformName=sub.target.platform_name, - targetName=sub.target.target_name, - cats=sub.categories, - tags=sub.tags, - target=sub.target.target, - ), - raw_subs, + subs = [ + SubscribeConfig( + platformName=sub.target.platform_name, + targetName=sub.target.target_name, + cats=sub.categories, + tags=sub.tags, + target=sub.target.target, ) - ) + for sub in raw_subs + ] res[group_id] = SubscribeGroupDetail(name=group["name"], subscribes=subs) return res @@ -174,9 +163,7 @@ async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp: @router.delete("/subs", dependencies=[Depends(check_group_permission)]) async def del_group_sub(groupNumber: int, platformName: str, target: str): try: - await config.del_subscribe( - TargetQQGroup(group_id=groupNumber), target, platformName - ) + await config.del_subscribe(TargetQQGroup(group_id=groupNumber), target, platformName) except (NoSuchUserException, NoSuchSubscribeException): raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe") return StatusResp(ok=True, msg="") @@ -204,13 +191,9 @@ async def get_weight_config(): @router.put("/weight", dependencies=[Depends(check_is_superuser)]) -async def update_weigth_config( - platformName: str, target: str, weight_config: WeightConfig -): +async def update_weigth_config(platformName: str, target: str, weight_config: WeightConfig): try: - await config.update_time_weight_config( - T_Target(target), platformName, weight_config - ) + await config.update_time_weight_config(T_Target(target), platformName, weight_config) except NoSuchTargetException: raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such subscribe") return StatusResp(ok=True, msg="") diff --git a/nonebot_bison/admin_page/jwt.py b/nonebot_bison/admin_page/jwt.py index 661621a..866c184 100644 --- a/nonebot_bison/admin_page/jwt.py +++ b/nonebot_bison/admin_page/jwt.py @@ -1,7 +1,6 @@ -import datetime import random import string -from typing import Optional +import datetime import jwt @@ -16,8 +15,8 @@ def pack_jwt(obj: dict) -> str: ) -def load_jwt(token: str) -> Optional[dict]: +def load_jwt(token: str) -> dict | None: try: return jwt.decode(token, _key, algorithms=["HS256"]) - except: + except Exception: return None diff --git a/nonebot_bison/admin_page/token_manager.py b/nonebot_bison/admin_page/token_manager.py index e540656..bb62d0a 100644 --- a/nonebot_bison/admin_page/token_manager.py +++ b/nonebot_bison/admin_page/token_manager.py @@ -1,6 +1,5 @@ import random import string -from typing import Optional from expiringdict import ExpiringDict @@ -9,7 +8,7 @@ class TokenManager: def __init__(self): self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60 * 10) - def get_user(self, token: str) -> Optional[tuple]: + def get_user(self, token: str) -> tuple | None: res = self.token_manager.get(token) assert res is None or isinstance(res, tuple) return res diff --git a/nonebot_bison/config/__init__.py b/nonebot_bison/config/__init__.py index 2fb9151..a04d41f 100644 --- a/nonebot_bison/config/__init__.py +++ b/nonebot_bison/config/__init__.py @@ -1,2 +1,4 @@ -from .db_config import config -from .utils import NoSuchSubscribeException, NoSuchTargetException, NoSuchUserException +from .db_config import config as config +from .utils import NoSuchUserException as NoSuchUserException +from .utils import NoSuchTargetException as NoSuchTargetException +from .utils import NoSuchSubscribeException as NoSuchSubscribeException diff --git a/nonebot_bison/config/config_legacy.py b/nonebot_bison/config/config_legacy.py index d892b5c..24e7e4d 100644 --- a/nonebot_bison/config/config_legacy.py +++ b/nonebot_bison/config/config_legacy.py @@ -1,20 +1,19 @@ -import json import os -from collections import defaultdict -from datetime import datetime +import json from os import path from pathlib import Path -from typing import DefaultDict, Literal, Mapping, TypedDict +from datetime import datetime +from collections import defaultdict +from typing import Literal, TypedDict -import nonebot from nonebot.log import logger from tinydb import Query, TinyDB +from ..utils import Singleton +from ..types import User, Target from ..platform import platform_manager from ..plugin_config import plugin_config -from ..types import Target, User -from ..utils import Singleton -from .utils import NoSuchSubscribeException, NoSuchUserException +from .utils import NoSuchUserException, NoSuchSubscribeException supported_target_type = platform_manager.keys() @@ -89,17 +88,16 @@ class Config(metaclass=Singleton): self.target_user_cat_cache = {} self.target_user_tag_cache = {} self.target_list = {} - self.next_index: DefaultDict[str, int] = defaultdict(lambda: 0) + self.next_index: defaultdict[str, int] = defaultdict(lambda: 0) else: self.available = False - def add_subscribe( - self, user, user_type, target, target_name, target_type, cats, tags - ): + def add_subscribe(self, user, user_type, target, target_name, target_type, cats, tags): user_query = Query() query = (user_query.user == user) & (user_query.user_type == user_type) if user_data := self.user_target.get(query): # update + assert not isinstance(user_data, list) subs: list = user_data.get("subs", []) subs.append( { @@ -132,9 +130,8 @@ class Config(metaclass=Singleton): def list_subscribe(self, user, user_type) -> list[SubscribeContent]: query = Query() - if user_sub := self.user_target.get( - (query.user == user) & (query.user_type == user_type) - ): + if user_sub := self.user_target.get((query.user == user) & (query.user_type == user_type)): + assert not isinstance(user_sub, list) return user_sub["subs"] return [] @@ -146,6 +143,7 @@ class Config(metaclass=Singleton): query = (user_query.user == user) & (user_query.user_type == user_type) if not (query_res := self.user_target.get(query)): raise NoSuchUserException() + assert not isinstance(query_res, list) subs = query_res.get("subs", []) for idx, sub in enumerate(subs): if sub.get("target") == target and sub.get("target_type") == target_type: @@ -155,13 +153,12 @@ class Config(metaclass=Singleton): return raise NoSuchSubscribeException() - def update_subscribe( - self, user, user_type, target, target_name, target_type, cats, tags - ): + def update_subscribe(self, user, user_type, target, target_name, target_type, cats, tags): user_query = Query() query = (user_query.user == user) & (user_query.user_type == user_type) if user_data := self.user_target.get(query): # update + assert not isinstance(user_data, list) subs: list = user_data.get("subs", []) find_flag = False for item in subs: @@ -182,19 +179,13 @@ class Config(metaclass=Singleton): def update_send_cache(self): res = {target_type: defaultdict(list) for target_type in supported_target_type} - cat_res = { - target_type: defaultdict(lambda: defaultdict(list)) - for target_type in supported_target_type - } - tag_res = { - target_type: defaultdict(lambda: defaultdict(list)) - for target_type in supported_target_type - } + cat_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} + tag_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} # res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} to_del = [] for user in self.user_target.all(): for sub in user.get("subs", []): - if not sub.get("target_type") in supported_target_type: + if sub.get("target_type") not in supported_target_type: to_del.append( { "user": user["user"], @@ -204,36 +195,28 @@ class Config(metaclass=Singleton): } ) continue - res[sub["target_type"]][sub["target"]].append( - User(user["user"], user["user_type"]) - ) - cat_res[sub["target_type"]][sub["target"]][ - "{}-{}".format(user["user_type"], user["user"]) - ] = sub["cats"] - tag_res[sub["target_type"]][sub["target"]][ - "{}-{}".format(user["user_type"], user["user"]) - ] = sub["tags"] + res[sub["target_type"]][sub["target"]].append(User(user["user"], user["user_type"])) + cat_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[ + "cats" + ] + tag_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[ + "tags" + ] self.target_user_cache = res self.target_user_cat_cache = cat_res self.target_user_tag_cache = tag_res for target_type in self.target_user_cache: - self.target_list[target_type] = list( - self.target_user_cache[target_type].keys() - ) + self.target_list[target_type] = list(self.target_user_cache[target_type].keys()) logger.info(f"Deleting {to_del}") for d in to_del: self.del_subscribe(**d) def get_sub_category(self, target_type, target, user_type, user): - return self.target_user_cat_cache[target_type][target][ - "{}-{}".format(user_type, user) - ] + return self.target_user_cat_cache[target_type][target][f"{user_type}-{user}"] def get_sub_tags(self, target_type, target, user_type, user): - return self.target_user_tag_cache[target_type][target][ - "{}-{}".format(user_type, user) - ] + return self.target_user_tag_cache[target_type][target][f"{user_type}-{user}"] def get_next_target(self, target_type): # FIXME 插入或删除target后对队列的影响(但是并不是大问题 diff --git a/nonebot_bison/config/db_config.py b/nonebot_bison/config/db_config.py index 38ef9af..ef6cc6b 100644 --- a/nonebot_bison/config/db_config.py +++ b/nonebot_bison/config/db_config.py @@ -1,19 +1,19 @@ import asyncio from collections import defaultdict -from datetime import datetime, time -from typing import Awaitable, Callable, Optional, Sequence +from datetime import time, datetime +from collections.abc import Callable, Sequence, Awaitable -from nonebot_plugin_datastore import create_session -from nonebot_plugin_saa import PlatformTarget -from sqlalchemy import delete, func, select -from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import selectinload +from sqlalchemy.exc import IntegrityError +from sqlalchemy import func, delete, select +from nonebot_plugin_saa import PlatformTarget +from nonebot_plugin_datastore import create_session -from ..types import Category, PlatformWeightConfigResp, Tag +from ..types import Tag from ..types import Target as T_Target -from ..types import TimeWeightConfig, UserSubInfo, WeightConfig -from .db_model import ScheduleTimeWeight, Subscribe, Target, User from .utils import NoSuchTargetException +from .db_model import User, Target, Subscribe, ScheduleTimeWeight +from ..types import Category, UserSubInfo, WeightConfig, TimeWeightConfig, PlatformWeightConfigResp def _get_time(): @@ -48,23 +48,17 @@ class DBConfig: ): async with create_session() as session: db_user_stmt = select(User).where(User.user_target == user.dict()) - db_user: Optional[User] = await session.scalar(db_user_stmt) + db_user: User | None = await session.scalar(db_user_stmt) if not db_user: db_user = User(user_target=user.dict()) session.add(db_user) db_target_stmt = ( - select(Target) - .where(Target.platform_name == platform_name) - .where(Target.target == target) + select(Target).where(Target.platform_name == platform_name).where(Target.target == target) ) - db_target: Optional[Target] = await session.scalar(db_target_stmt) + db_target: Target | None = await session.scalar(db_target_stmt) if not db_target: - db_target = Target( - target=target, platform_name=platform_name, target_name=target_name - ) - await asyncio.gather( - *[hook(platform_name, target) for hook in self.add_target_hook] - ) + db_target = Target(target=target, platform_name=platform_name, target_name=target_name) + await asyncio.gather(*[hook(platform_name, target) for hook in self.add_target_hook]) else: db_target.target_name = target_name subscribe = Subscribe( @@ -96,44 +90,25 @@ class DBConfig: """获取数据库中带有user、target信息的subscribe数据""" async with create_session() as session: query_stmt = ( - select(Subscribe) - .join(User) - .options(selectinload(Subscribe.target), selectinload(Subscribe.user)) + select(Subscribe).join(User).options(selectinload(Subscribe.target), selectinload(Subscribe.user)) ) subs = (await session.scalars(query_stmt)).all() return subs - async def del_subscribe( - self, user: PlatformTarget, target: str, platform_name: str - ): + async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str): async with create_session() as session: - user_obj = await session.scalar( - select(User).where(User.user_target == user.dict()) - ) + user_obj = await session.scalar(select(User).where(User.user_target == user.dict())) target_obj = await session.scalar( - select(Target).where( - Target.platform_name == platform_name, Target.target == target - ) - ) - await session.execute( - delete(Subscribe).where( - Subscribe.user == user_obj, Subscribe.target == target_obj - ) + select(Target).where(Target.platform_name == platform_name, Target.target == target) ) + await session.execute(delete(Subscribe).where(Subscribe.user == user_obj, Subscribe.target == target_obj)) target_count = await session.scalar( - select(func.count()) - .select_from(Subscribe) - .where(Subscribe.target == target_obj) + select(func.count()).select_from(Subscribe).where(Subscribe.target == target_obj) ) if target_count == 0: # delete empty target - await asyncio.gather( - *[ - hook(platform_name, T_Target(target)) - for hook in self.delete_target_hook - ] - ) + await asyncio.gather(*[hook(platform_name, T_Target(target)) for hook in self.delete_target_hook]) await session.commit() async def update_subscribe( @@ -165,29 +140,22 @@ class DBConfig: async def get_platform_target(self, platform_name: str) -> Sequence[Target]: async with create_session() as sess: subq = select(Subscribe.target_id).distinct().subquery() - query = ( - select(Target).join(subq).where(Target.platform_name == platform_name) - ) + query = select(Target).join(subq).where(Target.platform_name == platform_name) return (await sess.scalars(query)).all() - async def get_time_weight_config( - self, target: T_Target, platform_name: str - ) -> WeightConfig: + async def get_time_weight_config(self, target: T_Target, platform_name: str) -> WeightConfig: async with create_session() as sess: time_weight_conf = ( await sess.scalars( select(ScheduleTimeWeight) - .where( - Target.platform_name == platform_name, Target.target == target - ) + .where(Target.platform_name == platform_name, Target.target == target) .join(Target) ) ).all() targetObj = await sess.scalar( - select(Target).where( - Target.platform_name == platform_name, Target.target == target - ) + select(Target).where(Target.platform_name == platform_name, Target.target == target) ) + assert targetObj return WeightConfig( default=targetObj.default_schedule_weight, time_config=[ @@ -200,22 +168,16 @@ class DBConfig: ], ) - async def update_time_weight_config( - self, target: T_Target, platform_name: str, conf: WeightConfig - ): + async def update_time_weight_config(self, target: T_Target, platform_name: str, conf: WeightConfig): async with create_session() as sess: targetObj = await sess.scalar( - select(Target).where( - Target.platform_name == platform_name, Target.target == target - ) + select(Target).where(Target.platform_name == platform_name, Target.target == target) ) if not targetObj: raise NoSuchTargetException() target_id = targetObj.id targetObj.default_schedule_weight = conf.default - delete_statement = delete(ScheduleTimeWeight).where( - ScheduleTimeWeight.target_id == target_id - ) + delete_statement = delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id) await sess.execute(delete_statement) for time_conf in conf.time_config: new_conf = ScheduleTimeWeight( @@ -243,18 +205,13 @@ class DBConfig: key = f"{target.platform_name}-{target.target}" weight = target.default_schedule_weight for time_conf in target.time_weight: - if ( - time_conf.start_time <= cur_time - and time_conf.end_time > cur_time - ): + if time_conf.start_time <= cur_time and time_conf.end_time > cur_time: weight = time_conf.weight break res[key] = weight return res - async def get_platform_target_subscribers( - self, platform_name: str, target: T_Target - ) -> list[UserSubInfo]: + async def get_platform_target_subscribers(self, platform_name: str, target: T_Target) -> list[UserSubInfo]: async with create_session() as sess: query = ( select(Subscribe) @@ -263,16 +220,14 @@ class DBConfig: .options(selectinload(Subscribe.user)) ) subsribes = (await sess.scalars(query)).all() - return list( - map( - lambda subscribe: UserSubInfo( - PlatformTarget.deserialize(subscribe.user.user_target), - subscribe.categories, - subscribe.tags, - ), - subsribes, + return [ + UserSubInfo( + PlatformTarget.deserialize(subscribe.user.user_target), + subscribe.categories, + subscribe.tags, ) - ) + for subscribe in subsribes + ] async def get_all_weight_config( self, @@ -281,9 +236,7 @@ class DBConfig: async with create_session() as sess: query = select(Target) targets = (await sess.scalars(query)).all() - query = select(ScheduleTimeWeight).options( - selectinload(ScheduleTimeWeight.target) - ) + query = select(ScheduleTimeWeight).options(selectinload(ScheduleTimeWeight.target)) time_weights = (await sess.scalars(query)).all() for target in targets: @@ -293,9 +246,7 @@ class DBConfig: target=T_Target(target.target), target_name=target.target_name, platform_name=platform_name, - weight=WeightConfig( - default=target.default_schedule_weight, time_config=[] - ), + weight=WeightConfig(default=target.default_schedule_weight, time_config=[]), ) for time_weight_config in time_weights: diff --git a/nonebot_bison/config/db_migration.py b/nonebot_bison/config/db_migration.py index e20b1dc..08d3117 100644 --- a/nonebot_bison/config/db_migration.py +++ b/nonebot_bison/config/db_migration.py @@ -1,22 +1,17 @@ from nonebot.log import logger from nonebot_plugin_datastore.db import get_engine -from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate from sqlalchemy.ext.asyncio.session import AsyncSession +from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate +from .db_model import User, Target, Subscribe from .config_legacy import Config, ConfigContent, drop -from .db_model import Subscribe, Target, User async def data_migrate(): config = Config() if config.available: logger.warning("You are still using legacy db, migrating to sqlite") - all_subs: list[ConfigContent] = list( - map( - lambda item: ConfigContent(**item), - config.get_all_subscribe().all(), - ) - ) + all_subs: list[ConfigContent] = [ConfigContent(**item) for item in config.get_all_subscribe().all()] async with AsyncSession(get_engine()) as sess: user_to_create = [] subscribe_to_create = [] @@ -37,8 +32,7 @@ async def data_migrate(): if key in user_sub_set: # a user subscribe a target twice logger.error( - f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次," - "随机采用了一个订阅" + f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次,随机采用了一个订阅" # noqa: E501 ) continue user_sub_set.add(key) @@ -69,11 +63,7 @@ async def data_migrate(): tags=sub["tags"], ) subscribe_to_create.append(subscribe_obj) - sess.add_all( - user_to_create - + list(map(lambda x: x[0], platform_target_map.values())) - + subscribe_to_create - ) + sess.add_all(user_to_create + [x[0] for x in platform_target_map.values()] + subscribe_to_create) await sess.commit() drop() logger.info("migrate success") diff --git a/nonebot_bison/config/migrations/0571870f5222_init_db.py b/nonebot_bison/config/migrations/0571870f5222_init_db.py index d6e0c2c..347212a 100644 --- a/nonebot_bison/config/migrations/0571870f5222_init_db.py +++ b/nonebot_bison/config/migrations/0571870f5222_init_db.py @@ -1,7 +1,7 @@ """init db Revision ID: 0571870f5222 -Revises: +Revises: Create Date: 2022-03-21 19:18:13.762626 """ diff --git a/nonebot_bison/config/migrations/5da28f6facb3_rename_tables.py b/nonebot_bison/config/migrations/5da28f6facb3_rename_tables.py index c8eb5c7..20e7544 100644 --- a/nonebot_bison/config/migrations/5da28f6facb3_rename_tables.py +++ b/nonebot_bison/config/migrations/5da28f6facb3_rename_tables.py @@ -5,7 +5,6 @@ Revises: 5f3370328e44 Create Date: 2023-01-15 19:04:54.987491 """ -import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. diff --git a/nonebot_bison/config/migrations/c97c445e2bdb_add_constraint.py b/nonebot_bison/config/migrations/c97c445e2bdb_add_constraint.py index 9119d3b..807699e 100644 --- a/nonebot_bison/config/migrations/c97c445e2bdb_add_constraint.py +++ b/nonebot_bison/config/migrations/c97c445e2bdb_add_constraint.py @@ -5,7 +5,6 @@ Revises: 0571870f5222 Create Date: 2022-03-26 19:46:50.910721 """ -import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. @@ -18,14 +17,10 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### with op.batch_alter_table("subscribe", schema=None) as batch_op: - batch_op.create_unique_constraint( - "unique-subscribe-constraint", ["target_id", "user_id"] - ) + batch_op.create_unique_constraint("unique-subscribe-constraint", ["target_id", "user_id"]) with op.batch_alter_table("target", schema=None) as batch_op: - batch_op.create_unique_constraint( - "unique-target-constraint", ["target", "platform_name"] - ) + batch_op.create_unique_constraint("unique-target-constraint", ["target", "platform_name"]) with op.batch_alter_table("user", schema=None) as batch_op: batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"]) diff --git a/nonebot_bison/config/subs_io/nbesf_model/base.py b/nonebot_bison/config/subs_io/nbesf_model/base.py index 11ae2bb..f8e4b55 100644 --- a/nonebot_bison/config/subs_io/nbesf_model/base.py +++ b/nonebot_bison/config/subs_io/nbesf_model/base.py @@ -1,15 +1,14 @@ from abc import ABC -from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo from pydantic import BaseModel +from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo -from ....types import Category, Tag +from ....types import Tag, Category class NBESFBase(BaseModel, ABC): - version: int # 表示nbesf格式版本,有效版本从1开始 - groups: list = list() + groups: list = [] class Config: orm_mode = True diff --git a/nonebot_bison/config/subs_io/subs_io.py b/nonebot_bison/config/subs_io/subs_io.py index 9a16472..21c1310 100644 --- a/nonebot_bison/config/subs_io/subs_io.py +++ b/nonebot_bison/config/subs_io/subs_io.py @@ -1,25 +1,26 @@ +from typing import cast from collections import defaultdict -from typing import Callable, cast +from collections.abc import Callable -from nonebot.log import logger -from nonebot_plugin_datastore.db import create_session -from nonebot_plugin_saa import PlatformTarget from sqlalchemy import select -from sqlalchemy.orm.strategy_options import selectinload +from nonebot.log import logger from sqlalchemy.sql.selectable import Select +from nonebot_plugin_saa import PlatformTarget +from nonebot_plugin_datastore.db import create_session +from sqlalchemy.orm.strategy_options import selectinload -from ..db_model import Subscribe, User -from .nbesf_model import NBESFBase, v1, v2 from .utils import NBESFVerMatchErr +from ..db_model import User, Subscribe +from .nbesf_model import NBESFBase, v1, v2 async def subscribes_export(selector: Callable[[Select], Select]) -> v2.SubGroup: - """ 将Bison订阅导出为 Nonebot Bison Exchangable Subscribes File 标准格式的 SubGroup 类型数据 selector: - 对 sqlalchemy Select 对象的操作函数,用于限定查询范围 e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group") + 对 sqlalchemy Select 对象的操作函数,用于限定查询范围 + e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group") """ async with create_session() as sess: sub_stmt = select(Subscribe).join(User) diff --git a/nonebot_bison/platform/__init__.py b/nonebot_bison/platform/__init__.py index e8d7186..c99ce12 100644 --- a/nonebot_bison/platform/__init__.py +++ b/nonebot_bison/platform/__init__.py @@ -1,23 +1,22 @@ -from collections import defaultdict -from importlib import import_module from pathlib import Path from pkgutil import iter_modules -from typing import DefaultDict, Type +from collections import defaultdict +from importlib import import_module from .platform import Platform, make_no_target_group _package_dir = str(Path(__file__).resolve().parent) -for (_, module_name, _) in iter_modules([_package_dir]): +for _, module_name, _ in iter_modules([_package_dir]): import_module(f"{__name__}.{module_name}") -_platform_list: DefaultDict[str, list[Type[Platform]]] = defaultdict(list) +_platform_list: defaultdict[str, list[type[Platform]]] = defaultdict(list) for _platform in Platform.registry: if not _platform.enabled: continue _platform_list[_platform.platform_name].append(_platform) -platform_manager: dict[str, Type[Platform]] = dict() +platform_manager: dict[str, type[Platform]] = {} for name, platform_list in _platform_list.items(): if len(platform_list) == 1: platform_manager[name] = platform_list[0] diff --git a/nonebot_bison/platform/arknights.py b/nonebot_bison/platform/arknights.py index 2513de5..00f3fbd 100644 --- a/nonebot_bison/platform/arknights.py +++ b/nonebot_bison/platform/arknights.py @@ -1,25 +1,23 @@ import json -from typing import Any, Optional +from typing import Any -from bs4 import BeautifulSoup as bs from httpx import AsyncClient from nonebot.plugin import require +from bs4 import BeautifulSoup as bs from ..post import Post -from ..types import Category, RawPost, Target +from ..types import Target, RawPost, Category from ..utils.scheduler_config import SchedulerConfig -from .platform import CategoryNotRecognize, NewMessage, StatusChange +from .platform import NewMessage, StatusChange, CategoryNotRecognize class ArknightsSchedConf(SchedulerConfig): - name = "arknights" schedule_type = "interval" schedule_setting = {"seconds": 30} class Arknights(NewMessage): - categories = {1: "游戏公告"} platform_name = "arknights" name = "明日方舟游戏信息" @@ -30,9 +28,7 @@ class Arknights(NewMessage): has_target = False @classmethod - async def get_target_name( - cls, client: AsyncClient, target: Target - ) -> Optional[str]: + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: return "明日方舟游戏信息" async def get_sub_list(self, _) -> list[RawPost]: @@ -92,7 +88,6 @@ class Arknights(NewMessage): class AkVersion(StatusChange): - categories = {2: "更新信息"} platform_name = "arknights" name = "明日方舟游戏信息" @@ -103,15 +98,11 @@ class AkVersion(StatusChange): has_target = False @classmethod - async def get_target_name( - cls, client: AsyncClient, target: Target - ) -> Optional[str]: + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: return "明日方舟游戏信息" async def get_status(self, _): - res_ver = await self.client.get( - "https://ak-conf.hypergryph.com/config/prod/official/IOS/version" - ) + res_ver = await self.client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version") res_preanounce = await self.client.get( "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json" ) @@ -121,20 +112,10 @@ class AkVersion(StatusChange): def compare_status(self, _, old_status, new_status): res = [] - if ( - old_status.get("preAnnounceType") == 2 - and new_status.get("preAnnounceType") == 0 - ): - res.append( - Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息") - ) - elif ( - old_status.get("preAnnounceType") == 0 - and new_status.get("preAnnounceType") == 2 - ): - res.append( - Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息") - ) + if old_status.get("preAnnounceType") == 2 and new_status.get("preAnnounceType") == 0: + res.append(Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息")) # noqa: E501 + elif old_status.get("preAnnounceType") == 0 and new_status.get("preAnnounceType") == 2: + res.append(Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息")) # noqa: E501 if old_status.get("clientVersion") != new_status.get("clientVersion"): res.append(Post("arknights", text="游戏本体更新(大更新)", target_name="明日方舟更新信息")) if old_status.get("resVersion") != new_status.get("resVersion"): @@ -149,7 +130,6 @@ class AkVersion(StatusChange): class MonsterSiren(NewMessage): - categories = {3: "塞壬唱片新闻"} platform_name = "arknights" name = "明日方舟游戏信息" @@ -160,15 +140,11 @@ class MonsterSiren(NewMessage): has_target = False @classmethod - async def get_target_name( - cls, client: AsyncClient, target: Target - ) -> Optional[str]: + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: return "明日方舟游戏信息" async def get_sub_list(self, _) -> list[RawPost]: - raw_data = await self.client.get( - "https://monster-siren.hypergryph.com/api/news" - ) + raw_data = await self.client.get("https://monster-siren.hypergryph.com/api/news") return raw_data.json()["data"]["list"] def get_id(self, post: RawPost) -> Any: @@ -182,14 +158,12 @@ class MonsterSiren(NewMessage): async def parse(self, raw_post: RawPost) -> Post: url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}' - res = await self.client.get( - f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}' - ) + res = await self.client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}') raw_data = res.json() content = raw_data["data"]["content"] content = content.replace("
", "\n") soup = bs(content, "html.parser") - imgs = list(map(lambda x: x["src"], soup("img"))) + imgs = [x["src"] for x in soup("img")] text = f'{raw_post["title"]}\n{soup.text.strip()}' return Post( "monster-siren", @@ -203,7 +177,6 @@ class MonsterSiren(NewMessage): class TerraHistoricusComic(NewMessage): - categories = {4: "泰拉记事社漫画"} platform_name = "arknights" name = "明日方舟游戏信息" @@ -214,15 +187,11 @@ class TerraHistoricusComic(NewMessage): has_target = False @classmethod - async def get_target_name( - cls, client: AsyncClient, target: Target - ) -> Optional[str]: + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: return "明日方舟游戏信息" async def get_sub_list(self, _) -> list[RawPost]: - raw_data = await self.client.get( - "https://terra-historicus.hypergryph.com/api/recentUpdate" - ) + raw_data = await self.client.get("https://terra-historicus.hypergryph.com/api/recentUpdate") return raw_data.json()["data"] def get_id(self, post: RawPost) -> Any: diff --git a/nonebot_bison/platform/bilibili.py b/nonebot_bison/platform/bilibili.py index 9769348..c81b0f1 100644 --- a/nonebot_bison/platform/bilibili.py +++ b/nonebot_bison/platform/bilibili.py @@ -1,14 +1,14 @@ -import json import re +import json +from typing import Any from copy import deepcopy -from datetime import datetime, timedelta from enum import Enum, unique -from typing import Any, Literal, Optional +from typing_extensions import Self +from datetime import datetime, timedelta from httpx import AsyncClient from nonebot.log import logger -from pydantic import BaseModel, Field -from typing_extensions import Self +from pydantic import Field, BaseModel from ..post import Post from ..types import ApiError, Category, RawPost, Tag, Target @@ -25,9 +25,7 @@ class BilibiliSchedConf(SchedulerConfig): cookie_expire_time = timedelta(hours=5) def __init__(self): - self._client_refresh_time = datetime( - year=2000, month=1, day=1 - ) # an expired time + self._client_refresh_time = datetime(year=2000, month=1, day=1) # an expired time super().__init__() async def _init_session(self): @@ -69,12 +67,8 @@ class Bilibili(NewMessage): parse_target_promot = "请输入用户主页的链接" @classmethod - async def get_target_name( - cls, client: AsyncClient, target: Target - ) -> Optional[str]: - res = await client.get( - "https://api.bilibili.com/x/web-interface/card", params={"mid": target} - ) + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: + res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target}) res.raise_for_status() res_data = res.json() if res_data["code"]: @@ -129,12 +123,7 @@ class Bilibili(NewMessage): return self._do_get_category(post_type) def get_tags(self, raw_post: RawPost) -> list[Tag]: - return [ - *map( - lambda tp: tp["topic_name"], - raw_post["display"]["topic_info"]["topic_details"], - ) - ] + return [*(tp["topic_name"] for tp in raw_post["display"]["topic_info"]["topic_details"])] def _get_info(self, post_type: Category, card) -> tuple[str, list]: if post_type == 1: @@ -178,24 +167,16 @@ class Bilibili(NewMessage): url = "" if post_type == 1: # 一般动态 - url = "https://t.bilibili.com/{}".format( - raw_post["desc"]["dynamic_id_str"] - ) + url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"]) elif post_type == 2: # 专栏文章 - url = "https://www.bilibili.com/read/cv{}".format( - raw_post["desc"]["rid"] - ) + url = "https://www.bilibili.com/read/cv{}".format(raw_post["desc"]["rid"]) elif post_type == 3: # 视频 - url = "https://www.bilibili.com/video/{}".format( - raw_post["desc"]["bvid"] - ) + url = "https://www.bilibili.com/video/{}".format(raw_post["desc"]["bvid"]) elif post_type == 4: # 纯文字 - url = "https://t.bilibili.com/{}".format( - raw_post["desc"]["dynamic_id_str"] - ) + url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"]) text, pic = self._get_info(post_type, card_content) elif post_type == 5: # 转发 @@ -261,10 +242,7 @@ class Bilibililive(StatusChange): def get_live_action(self, old_info: Self) -> "Bilibililive.LiveAction": status = Bilibililive.LiveStatus action = Bilibililive.LiveAction - if ( - old_info.live_status in [status.OFF, status.CYCLE] - and self.live_status == status.ON - ): + if old_info.live_status in [status.OFF, status.CYCLE] and self.live_status == status.ON: return action.TURN_ON elif old_info.live_status == status.ON and self.live_status in [ status.OFF, @@ -281,12 +259,8 @@ class Bilibililive(StatusChange): return action.OFF @classmethod - async def get_target_name( - cls, client: AsyncClient, target: Target - ) -> Optional[str]: - res = await client.get( - "https://api.bilibili.com/x/web-interface/card", params={"mid": target} - ) + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: + res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target}) res_data = json.loads(res.text) if res_data["code"]: return None @@ -382,9 +356,7 @@ class BilibiliBangumi(StatusChange): _url = "https://api.bilibili.com/pgc/review/user" @classmethod - async def get_target_name( - cls, client: AsyncClient, target: Target - ) -> Optional[str]: + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: res = await client.get(cls._url, params={"media_id": target}) res_data = res.json() if res_data["code"]: diff --git a/nonebot_bison/platform/ff14.py b/nonebot_bison/platform/ff14.py index 61ebc24..c7af6d4 100644 --- a/nonebot_bison/platform/ff14.py +++ b/nonebot_bison/platform/ff14.py @@ -1,15 +1,14 @@ -from typing import Any, Optional +from typing import Any from httpx import AsyncClient from ..post import Post -from ..types import RawPost, Target from ..utils import scheduler from .platform import NewMessage +from ..types import Target, RawPost class FF14(NewMessage): - categories = {} platform_name = "ff14" name = "最终幻想XIV官方公告" @@ -21,9 +20,7 @@ class FF14(NewMessage): has_target = False @classmethod - async def get_target_name( - cls, client: AsyncClient, target: Target - ) -> Optional[str]: + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: return "最终幻想XIV官方公告" async def get_sub_list(self, _) -> list[RawPost]: diff --git a/nonebot_bison/platform/mcbbsnews.py b/nonebot_bison/platform/mcbbsnews.py index f8020ab..1784698 100644 --- a/nonebot_bison/platform/mcbbsnews.py +++ b/nonebot_bison/platform/mcbbsnews.py @@ -2,15 +2,15 @@ import re import time import traceback -from bs4 import BeautifulSoup, Tag from httpx import AsyncClient from nonebot.log import logger +from bs4 import Tag, BeautifulSoup from nonebot.plugin import require from ..post import Post -from ..types import Category, RawPost, Target +from ..types import Target, RawPost, Category from ..utils import SchedulerConfig, http_client -from .platform import CategoryNotRecognize, CategoryNotSupport, NewMessage +from .platform import NewMessage, CategoryNotSupport, CategoryNotRecognize class McbbsnewsSchedConf(SchedulerConfig): @@ -134,9 +134,9 @@ class McbbsNews(NewMessage): if categoty_name in category_values: category_id = category_keys[category_values.index(categoty_name)] elif categoty_name in known_category_values: - raise CategoryNotSupport("McbbsNews订阅暂不支持 {}".format(categoty_name)) + raise CategoryNotSupport(f"McbbsNews订阅暂不支持 {categoty_name}") else: - raise CategoryNotRecognize("Mcbbsnews订阅尚未识别 {}".format(categoty_name)) + raise CategoryNotRecognize(f"Mcbbsnews订阅尚未识别 {categoty_name}") return category_id async def parse(self, post: RawPost) -> Post: @@ -170,7 +170,7 @@ class McbbsNews(NewMessage): 一般而言每条新闻的长度都很可观,图片生成时间比较喜人 """ require("nonebot_plugin_htmlrender") - from nonebot_plugin_htmlrender import capture_element, text_to_pic + from nonebot_plugin_htmlrender import text_to_pic, capture_element try: assert url @@ -181,7 +181,7 @@ class McbbsNews(NewMessage): device_scale_factor=3, ) assert pic_data - except: + except Exception: err_info = traceback.format_exc() logger.warning(f"渲染错误:{err_info}") diff --git a/nonebot_bison/platform/ncm.py b/nonebot_bison/platform/ncm.py index 4688f22..34883f7 100644 --- a/nonebot_bison/platform/ncm.py +++ b/nonebot_bison/platform/ncm.py @@ -1,23 +1,21 @@ import re -from typing import Any, Optional +from typing import Any from httpx import AsyncClient from ..post import Post -from ..types import ApiError, RawPost, Target -from ..utils import SchedulerConfig from .platform import NewMessage +from ..utils import SchedulerConfig +from ..types import Target, RawPost, ApiError class NcmSchedConf(SchedulerConfig): - name = "music.163.com" schedule_type = "interval" schedule_setting = {"minutes": 1} class NcmArtist(NewMessage): - categories = {} platform_name = "ncm-artist" enable_tag = False @@ -29,11 +27,9 @@ class NcmArtist(NewMessage): parse_target_promot = "请输入歌手主页(包含数字ID)的链接" @classmethod - async def get_target_name( - cls, client: AsyncClient, target: Target - ) -> Optional[str]: + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: res = await client.get( - "https://music.163.com/api/artist/albums/{}".format(target), + f"https://music.163.com/api/artist/albums/{target}", headers={"Referer": "https://music.163.com/"}, ) res_data = res.json() @@ -45,16 +41,14 @@ class NcmArtist(NewMessage): async def parse_target(cls, target_text: str) -> Target: if re.match(r"^\d+$", target_text): return Target(target_text) - elif match := re.match( - r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text - ): + elif match := re.match(r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text): return Target(match.group(1)) else: raise cls.ParseTargetException() async def get_sub_list(self, target: Target) -> list[RawPost]: res = await self.client.get( - "https://music.163.com/api/artist/albums/{}".format(target), + f"https://music.163.com/api/artist/albums/{target}", headers={"Referer": "https://music.163.com/"}, ) res_data = res.json() @@ -74,13 +68,10 @@ class NcmArtist(NewMessage): target_name = raw_post["artist"]["name"] pics = [raw_post["picUrl"]] url = "https://music.163.com/#/album?id={}".format(raw_post["id"]) - return Post( - "ncm-artist", text=text, url=url, pics=pics, target_name=target_name - ) + return Post("ncm-artist", text=text, url=url, pics=pics, target_name=target_name) class NcmRadio(NewMessage): - categories = {} platform_name = "ncm-radio" enable_tag = False @@ -92,9 +83,7 @@ class NcmRadio(NewMessage): parse_target_promot = "请输入主播电台主页(包含数字ID)的链接" @classmethod - async def get_target_name( - cls, client: AsyncClient, target: Target - ) -> Optional[str]: + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: res = await client.post( "http://music.163.com/api/dj/program/byradio", headers={"Referer": "https://music.163.com/"}, @@ -109,9 +98,7 @@ class NcmRadio(NewMessage): async def parse_target(cls, target_text: str) -> Target: if re.match(r"^\d+$", target_text): return Target(target_text) - elif match := re.match( - r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text - ): + elif match := re.match(r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text): return Target(match.group(1)) else: raise cls.ParseTargetException() diff --git a/nonebot_bison/platform/platform.py b/nonebot_bison/platform/platform.py index 283bf44..9ec9073 100644 --- a/nonebot_bison/platform/platform.py +++ b/nonebot_bison/platform/platform.py @@ -1,29 +1,32 @@ -import json import ssl +import json import time import typing +from typing import Any +from dataclasses import dataclass from abc import ABC, abstractmethod from collections import defaultdict -from dataclasses import dataclass -from typing import Any, Collection, Optional, Type +from collections.abc import Collection import httpx from httpx import AsyncClient from nonebot.log import logger from nonebot_plugin_saa import PlatformTarget -from ..plugin_config import plugin_config from ..post import Post -from ..types import Category, RawPost, Tag, Target, UserSubInfo +from ..plugin_config import plugin_config from ..utils import ProcessContext, SchedulerConfig +from ..types import Tag, Target, RawPost, Category, UserSubInfo class CategoryNotSupport(Exception): - "raise in get_category, when you know the category of the post but don't want to support it or don't support its parsing yet" + """raise in get_category, when you know the category of the post + but don't want to support it or don't support its parsing yet + """ class CategoryNotRecognize(Exception): - "raise in get_category, when you don't know the category of post" + """raise in get_category, when you don't know the category of post""" class RegistryMeta(type): @@ -42,7 +45,6 @@ class RegistryMeta(type): class PlatformMeta(RegistryMeta): - categories: dict[Category, str] store: dict[Target, Any] @@ -60,8 +62,7 @@ class PlatformABCMeta(PlatformMeta, ABC): class Platform(metaclass=PlatformABCMeta, base=True): - - scheduler: Type[SchedulerConfig] + scheduler: type[SchedulerConfig] ctx: ProcessContext is_common: bool enabled: bool @@ -70,16 +71,14 @@ class Platform(metaclass=PlatformABCMeta, base=True): categories: dict[Category, str] enable_tag: bool platform_name: str - parse_target_promot: Optional[str] = None - registry: list[Type["Platform"]] + parse_target_promot: str | None = None + registry: list[type["Platform"]] client: AsyncClient reverse_category: dict[str, Category] @classmethod @abstractmethod - async def get_target_name( - cls, client: AsyncClient, target: Target - ) -> Optional[str]: + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: ... @abstractmethod @@ -95,11 +94,7 @@ class Platform(metaclass=PlatformABCMeta, base=True): return await self.fetch_new_post(target, users) except httpx.RequestError as err: if plugin_config.bison_show_network_warning: - logger.warning( - "network connection error: {}, url: {}".format( - type(err), err.request.url - ) - ) + logger.warning(f"network connection error: {type(err)}, url: {err.request.url}") return [] except ssl.SSLError as err: if plugin_config.bison_show_network_warning: @@ -130,7 +125,7 @@ class Platform(metaclass=PlatformABCMeta, base=True): return Target(target_string) @abstractmethod - def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]: + def get_tags(self, raw_post: RawPost) -> Collection[Tag] | None: "Return Tag list of given RawPost" @classmethod @@ -201,9 +196,7 @@ class Platform(metaclass=PlatformABCMeta, base=True): ) -> list[tuple[PlatformTarget, list[Post]]]: res: list[tuple[PlatformTarget, list[Post]]] = [] for user, cats, required_tags in users: - user_raw_post = await self.filter_user_custom( - new_posts, cats, required_tags - ) + user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags) user_post: list[Post] = [] for raw_post in user_raw_post: user_post.append(await self.do_parse(raw_post)) @@ -211,7 +204,7 @@ class Platform(metaclass=PlatformABCMeta, base=True): return res @abstractmethod - def get_category(self, post: RawPost) -> Optional[Category]: + def get_category(self, post: RawPost) -> Category | None: "Return category of given Rawpost" raise NotImplementedError() @@ -221,7 +214,7 @@ class MessageProcess(Platform, abstract=True): def __init__(self, ctx: ProcessContext, client: AsyncClient): super().__init__(ctx, client) - self.parse_cache: dict[Any, Post] = dict() + self.parse_cache: dict[Any, Post] = {} @abstractmethod def get_id(self, post: RawPost) -> Any: @@ -246,7 +239,7 @@ class MessageProcess(Platform, abstract=True): "Get post list of the given target" @abstractmethod - def get_date(self, post: RawPost) -> Optional[int]: + def get_date(self, post: RawPost) -> int | None: "Get post timestamp and return, return None if can't get the time" async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]: @@ -286,9 +279,7 @@ class NewMessage(MessageProcess, abstract=True): inited: bool exists_posts: set[Any] - async def filter_common_with_diff( - self, target: Target, raw_post_list: list[RawPost] - ) -> list[RawPost]: + async def filter_common_with_diff(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]: filtered_post = await self.filter_common(raw_post_list) store = self.get_stored_data(target) or self.MessageStorage(False, set()) res = [] @@ -297,11 +288,7 @@ class NewMessage(MessageProcess, abstract=True): for raw_post in filtered_post: post_id = self.get_id(raw_post) store.exists_posts.add(post_id) - logger.info( - "init {}-{} with {}".format( - self.platform_name, target, store.exists_posts - ) - ) + logger.info(f"init {self.platform_name}-{target} with {store.exists_posts}") store.inited = True else: for raw_post in filtered_post: @@ -400,12 +387,11 @@ class SimplePost(MessageProcess, abstract=True): return res -def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]: - +def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]: if typing.TYPE_CHECKING: class NoTargetGroup(Platform, abstract=True): - platform_list: list[Type[Platform]] + platform_list: list[type[Platform]] platform_obj_list: list[Platform] DUMMY_STR = "_DUMMY" @@ -418,24 +404,18 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]: for platform in platform_list: if platform.has_target: - raise RuntimeError( - "Platform {} should have no target".format(platform.name) - ) + raise RuntimeError(f"Platform {platform.name} should have no target") if name == DUMMY_STR: name = platform.name elif name != platform.name: - raise RuntimeError("Platform name for {} not fit".format(platform_name)) + raise RuntimeError(f"Platform name for {platform_name} not fit") platform_category_key_set = set(platform.categories.keys()) if platform_category_key_set & categories_keys: - raise RuntimeError( - "Platform categories for {} duplicate".format(platform_name) - ) + raise RuntimeError(f"Platform categories for {platform_name} duplicate") categories_keys |= platform_category_key_set categories.update(platform.categories) if platform.scheduler != scheduler: - raise RuntimeError( - "Platform scheduler for {} not fit".format(platform_name) - ) + raise RuntimeError(f"Platform scheduler for {platform_name} not fit") def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient): Platform.__init__(self, ctx, client) @@ -444,15 +424,13 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]: self.platform_obj_list.append(platform_class(ctx, client)) def __str__(self: "NoTargetGroup") -> str: - return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]" + return "[" + " ".join(x.name for x in self.platform_list) + "]" @classmethod async def get_target_name(cls, client: AsyncClient, target: Target): return await platform_list[0].get_target_name(client, target) - async def fetch_new_post( - self: "NoTargetGroup", target: Target, users: list[UserSubInfo] - ): + async def fetch_new_post(self: "NoTargetGroup", target: Target, users: list[UserSubInfo]): res = defaultdict(list) for platform in self.platform_obj_list: platform_res = await platform.fetch_new_post(target=target, users=users) diff --git a/nonebot_bison/platform/rss.py b/nonebot_bison/platform/rss.py index cbcddd4..94c584c 100644 --- a/nonebot_bison/platform/rss.py +++ b/nonebot_bison/platform/rss.py @@ -1,10 +1,10 @@ import calendar import time -from typing import Any, Optional +from typing import Any import feedparser -from bs4 import BeautifulSoup as bs from httpx import AsyncClient +from bs4 import BeautifulSoup as bs from ..post import Post from ..types import RawPost, Target @@ -20,7 +20,6 @@ class RssSchedConf(SchedulerConfig): class Rss(NewMessage): - categories = {} enable_tag = False platform_name = "rss" @@ -31,9 +30,7 @@ class Rss(NewMessage): has_target = True @classmethod - async def get_target_name( - cls, client: AsyncClient, target: Target - ) -> Optional[str]: + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: res = await client.get(target, timeout=10.0) feed = feedparser.parse(res.text) return feed["feed"]["title"] @@ -69,7 +66,7 @@ class Rss(NewMessage): else: text = f"{title}\n\n{desc}" - pics = list(map(lambda x: x.attrs["src"], soup("img"))) + pics = [x.attrs["src"] for x in soup("img")] if raw_post.get("media_content"): for media in raw_post["media_content"]: if media.get("medium") == "image" and media.get("url"): diff --git a/nonebot_bison/platform/weibo.py b/nonebot_bison/platform/weibo.py index 2fccea2..e8d4197 100644 --- a/nonebot_bison/platform/weibo.py +++ b/nonebot_bison/platform/weibo.py @@ -1,17 +1,16 @@ -import json import re -from collections.abc import Callable +import json +from typing import Any from datetime import datetime -from typing import Any, Optional -from bs4 import BeautifulSoup as bs from httpx import AsyncClient from nonebot.log import logger +from bs4 import BeautifulSoup as bs from ..post import Post -from ..types import * -from ..utils import SchedulerConfig, http_client from .platform import NewMessage +from ..utils import SchedulerConfig, http_client +from ..types import Tag, Target, RawPost, ApiError, Category class WeiboSchedConf(SchedulerConfig): @@ -21,7 +20,6 @@ class WeiboSchedConf(SchedulerConfig): class Weibo(NewMessage): - categories = { 1: "转发", 2: "视频", @@ -38,13 +36,9 @@ class Weibo(NewMessage): parse_target_promot = "请输入用户主页(包含数字UID)的链接" @classmethod - async def get_target_name( - cls, client: AsyncClient, target: Target - ) -> Optional[str]: + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: param = {"containerid": "100505" + target} - res = await client.get( - "https://m.weibo.cn/api/container/getIndex", params=param - ) + res = await client.get("https://m.weibo.cn/api/container/getIndex", params=param) res_dict = json.loads(res.text) if res_dict.get("ok") == 1: return res_dict["data"]["userInfo"]["screen_name"] @@ -63,13 +57,14 @@ class Weibo(NewMessage): async def get_sub_list(self, target: Target) -> list[RawPost]: params = {"containerid": "107603" + target} - res = await self.client.get( - "https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0 - ) + res = await self.client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0) res_data = json.loads(res.text) if not res_data["ok"] and res_data["msg"] != "这里还没有内容": raise ApiError(res.request.url) - custom_filter: Callable[[RawPost], bool] = lambda d: d["card_type"] == 9 + + def custom_filter(d: RawPost) -> bool: + return d["card_type"] == 9 + return list(filter(custom_filter, res_data["data"]["cards"])) def get_id(self, post: RawPost) -> Any: @@ -79,44 +74,32 @@ class Weibo(NewMessage): return raw_post["card_type"] == 9 def get_date(self, raw_post: RawPost) -> float: - created_time = datetime.strptime( - raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y" - ) + created_time = datetime.strptime(raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y") return created_time.timestamp() - def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]: + def get_tags(self, raw_post: RawPost) -> list[Tag] | None: "Return Tag list of given RawPost" text = raw_post["mblog"]["text"] soup = bs(text, "html.parser") - res = list( - map( - lambda x: x[1:-1], - filter( - lambda s: s[0] == "#" and s[-1] == "#", - map(lambda x: x.text, soup.find_all("span", class_="surl-text")), - ), + res = [ + x[1:-1] + for x in filter( + lambda s: s[0] == "#" and s[-1] == "#", + (x.text for x in soup.find_all("span", class_="surl-text")), ) - ) - super_topic_img = soup.find( - "img", src=re.compile(r"timeline_card_small_super_default") - ) + ] + super_topic_img = soup.find("img", src=re.compile(r"timeline_card_small_super_default")) if super_topic_img: try: - res.append( - super_topic_img.parent.parent.find("span", class_="surl-text").text # type: ignore - + "超话" - ) - except: - logger.info("super_topic extract error: {}".format(text)) + res.append(super_topic_img.parent.parent.find("span", class_="surl-text").text + "超话") # type: ignore + except Exception: + logger.info(f"super_topic extract error: {text}") return res def get_category(self, raw_post: RawPost) -> Category: if raw_post["mblog"].get("retweeted_status"): return Category(1) - elif ( - raw_post["mblog"].get("page_info") - and raw_post["mblog"]["page_info"].get("type") == "video" - ): + elif raw_post["mblog"].get("page_info") and raw_post["mblog"]["page_info"].get("type") == "video": return Category(2) elif raw_post["mblog"].get("pics"): return Category(3) @@ -129,7 +112,8 @@ class Weibo(NewMessage): async def parse(self, raw_post: RawPost) -> Post: header = { - "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng," + "*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", "accept-language": "zh-CN,zh;q=0.9", "authority": "m.weibo.cn", "cache-control": "max-age=0", @@ -147,26 +131,16 @@ class Weibo(NewMessage): retweeted = True pic_num = info["retweeted_status"]["pic_num"] if retweeted else info["pic_num"] if info["isLongText"] or pic_num > 9: - res = await self.client.get( - "https://m.weibo.cn/detail/{}".format(info["mid"]), headers=header - ) + res = await self.client.get(f"https://m.weibo.cn/detail/{info['mid']}", headers=header) try: match = re.search(r'"status": ([\s\S]+),\s+"call"', res.text) assert match full_json_text = match.group(1) info = json.loads(full_json_text) - except: - logger.info( - "detail message error: https://m.weibo.cn/detail/{}".format( - info["mid"] - ) - ) + except Exception: + logger.info(f"detail message error: https://m.weibo.cn/detail/{info['mid']}") parsed_text = self._get_text(info["text"]) - raw_pics_list = ( - info["retweeted_status"].get("pics", []) - if retweeted - else info.get("pics", []) - ) + raw_pics_list = info["retweeted_status"].get("pics", []) if retweeted else info.get("pics", []) pic_urls = [img["large"]["url"] for img in raw_pics_list] pics = [] for pic_url in pic_urls: @@ -174,7 +148,7 @@ class Weibo(NewMessage): res = await client.get(pic_url) res.raise_for_status() pics.append(res.content) - detail_url = "https://weibo.com/{}/{}".format(info["user"]["id"], info["bid"]) + detail_url = f"https://weibo.com/{info['user']['id']}/{info['bid']}" # return parsed_text, detail_url, pic_urls return Post( "weibo", diff --git a/nonebot_bison/plugin_config.py b/nonebot_bison/plugin_config.py index cb35ed0..041b38a 100644 --- a/nonebot_bison/plugin_config.py +++ b/nonebot_bison/plugin_config.py @@ -1,11 +1,8 @@ -from typing import Optional - import nonebot from pydantic import BaseSettings class PlugConfig(BaseSettings): - bison_config_path: str = "" bison_use_pic: bool = False bison_init_filter: bool = True @@ -17,8 +14,12 @@ class PlugConfig(BaseSettings): bison_use_pic_merge: int = 0 # 多图片时启用图片合并转发(仅限群) # 0:不启用;1:首条消息单独发送,剩余照片合并转发;2以及以上:所有消息全部合并转发 bison_resend_times: int = 0 - bison_proxy: Optional[str] - bison_ua: str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36" + bison_proxy: str | None + bison_ua: str = ( + "Mozilla/5.0 (X11; Linux x86_64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/51.0.2704.103 Safari/537.36" + ) bison_show_network_warning: bool = True class Config: diff --git a/nonebot_bison/post/__init__.py b/nonebot_bison/post/__init__.py index ff93bec..3900f47 100644 --- a/nonebot_bison/post/__init__.py +++ b/nonebot_bison/post/__init__.py @@ -1,3 +1,3 @@ from .post import Post -__all__ = ["Post", "CustomPost"] +__all__ = ["Post"] diff --git a/nonebot_bison/post/abstract_post.py b/nonebot_bison/post/abstract_post.py index ee055d5..a8d88de 100644 --- a/nonebot_bison/post/abstract_post.py +++ b/nonebot_bison/post/abstract_post.py @@ -1,7 +1,6 @@ -from abc import abstractmethod -from dataclasses import dataclass, field from functools import reduce -from typing import Optional +from abc import abstractmethod +from dataclasses import field, dataclass from nonebot_plugin_saa import MessageFactory, MessageSegmentFactory @@ -25,12 +24,12 @@ class BasePost: class OptionalMixin: # Because of https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses - override_use_pic: Optional[bool] = None + override_use_pic: bool | None = None compress: bool = False extra_msg: list[MessageFactory] = field(default_factory=list) def _use_pic(self): - if not self.override_use_pic is None: + if self.override_use_pic is not None: return self.override_use_pic return plugin_config.bison_use_pic @@ -44,13 +43,9 @@ class AbstractPost(OptionalMixin, BasePost): msg_segments = await self.generate_text_messages() if msg_segments: if self.compress: - msgs = [ - reduce(lambda x, y: x.append(y), msg_segments, MessageFactory([])) - ] + msgs = [reduce(lambda x, y: x.append(y), msg_segments, MessageFactory([]))] else: - msgs = list( - map(lambda msg_segment: MessageFactory([msg_segment]), msg_segments) - ) + msgs = [MessageFactory([msg_segment]) for msg_segment in msg_segments] else: msgs = [] msgs.extend(self.extra_msg) diff --git a/nonebot_bison/post/custom_post.py b/nonebot_bison/post/custom_post.py index 951a9c8..4921bc2 100644 --- a/nonebot_bison/post/custom_post.py +++ b/nonebot_bison/post/custom_post.py @@ -1,19 +1,17 @@ -from dataclasses import dataclass, field -from typing import Optional +from dataclasses import field, dataclass -from nonebot.adapters.onebot.v11 import MessageSegment from nonebot.log import logger from nonebot.plugin import require -from nonebot_plugin_saa import Image, MessageFactory, MessageSegmentFactory, Text +from nonebot.adapters.onebot.v11 import MessageSegment +from nonebot_plugin_saa import Text, Image, MessageSegmentFactory -from .abstract_post import AbstractPost, BasePost +from .abstract_post import BasePost, AbstractPost @dataclass class _CustomPost(BasePost): - ms_factories: list[MessageSegmentFactory] = field(default_factory=list) - css_path: Optional[str] = None # 模板文件所用css路径 + css_path: str | None = None # 模板文件所用css路径 async def generate_text_messages(self) -> list[MessageSegmentFactory]: return self.ms_factories @@ -31,15 +29,13 @@ class _CustomPost(BasePost): for message_segment in self.ms_factories: match message_segment: case Text(data={"text": text}): - md += "{}