From a48ea0e947a14190ee601ba813f1efd47a04dfff Mon Sep 17 00:00:00 2001 From: suyiiyii Date: Thu, 5 Dec 2024 16:08:42 +0800 Subject: [PATCH] =?UTF-8?q?:bug:=20=E4=BF=AE=E5=A4=8D=20cookie=20=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=20type=20hint=20(#658)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot_bison/admin_page/api.py | 14 ++++++++++---- nonebot_bison/config/db_config.py | 4 ++++ nonebot_bison/config/subs_io/nbesf_model/v3.py | 7 +++++-- nonebot_bison/config/subs_io/subs_io.py | 2 +- nonebot_bison/config/subs_io/utils.py | 6 ++++-- nonebot_bison/script/cli.py | 2 +- nonebot_bison/sub_manager/add_cookie_target.py | 6 ++---- nonebot_bison/sub_manager/utils.py | 8 ++++---- nonebot_bison/utils/site.py | 8 ++++---- tests/subs_io/test_cli.py | 2 +- tests/subs_io/test_subs_io.py | 6 +++--- 11 files changed, 39 insertions(+), 26 deletions(-) diff --git a/nonebot_bison/admin_page/api.py b/nonebot_bison/admin_page/api.py index 14111c8..da8208c 100644 --- a/nonebot_bison/admin_page/api.py +++ b/nonebot_bison/admin_page/api.py @@ -22,6 +22,7 @@ from ..utils.site import CookieClientManager, site_manager, is_cookie_client_man from ..config import NoSuchUserException, NoSuchTargetException, NoSuchSubscribeException, config from .types import ( Cookie, + Target, TokenResp, GlobalConf, SiteConfig, @@ -211,7 +212,7 @@ async def update_weigth_config(platformName: str, target: str, weight_config: We @router.get("/cookie", dependencies=[Depends(check_is_superuser)]) -async def get_cookie(site_name: str = None, target: str = None) -> list[Cookie]: +async def get_cookie(site_name: str | None = None, target: str | None = None) -> list[Cookie]: cookies_in_db = await config.get_cookie(site_name, is_anonymous=False) return [ Cookie( @@ -250,7 +251,12 @@ async def get_cookie_target( cookie_targets = await config.get_cookie_target() # TODO: filter in SQL return [ - x + CookieTarget( + target=Target( + platform_name=x.target.platform_name, target_name=x.target.target_name, target=x.target.target + ), + cookie_id=x.cookie.id, + ) for x in cookie_targets if (site_name is None or x.cookie.site_name == site_name) and (target is None or x.target.target == target) @@ -259,13 +265,13 @@ async def get_cookie_target( @router.post("/cookie_target", dependencies=[Depends(check_is_superuser)]) -async def add_cookie_target(platform_name: str, target: str, cookie_id: int) -> StatusResp: +async def add_cookie_target(platform_name: str, target: T_Target, cookie_id: int) -> StatusResp: await config.add_cookie_target(target, platform_name, cookie_id) return StatusResp(ok=True, msg="") @router.delete("/cookie_target", dependencies=[Depends(check_is_superuser)]) -async def del_cookie_target(platform_name: str, target: str, cookie_id: int) -> StatusResp: +async def del_cookie_target(platform_name: str, target: T_Target, cookie_id: int) -> StatusResp: await config.delete_cookie_target(target, platform_name, cookie_id) return StatusResp(ok=True, msg="") diff --git a/nonebot_bison/config/db_config.py b/nonebot_bison/config/db_config.py index 7451292..c76078c 100644 --- a/nonebot_bison/config/db_config.py +++ b/nonebot_bison/config/db_config.py @@ -288,6 +288,8 @@ class DBConfig: async def get_cookie_by_id(self, cookie_id: int) -> Cookie: async with create_session() as sess: cookie = await sess.scalar(select(Cookie).where(Cookie.id == cookie_id)) + if not cookie: + raise NoSuchTargetException(f"cookie {cookie_id} not found") return cookie async def add_cookie(self, cookie: Cookie) -> int: @@ -317,6 +319,8 @@ class DBConfig: .outerjoin(CookieTarget) .options(selectinload(Cookie.targets)) ) + if not cookie: + raise NoSuchTargetException(f"cookie {cookie_id} not found") if len(cookie.targets) > 0: raise Exception(f"cookie {cookie.id} in use") await sess.execute(delete(Cookie).where(Cookie.id == cookie_id)) diff --git a/nonebot_bison/config/subs_io/nbesf_model/v3.py b/nonebot_bison/config/subs_io/nbesf_model/v3.py index a2d5e42..c6f1c03 100644 --- a/nonebot_bison/config/subs_io/nbesf_model/v3.py +++ b/nonebot_bison/config/subs_io/nbesf_model/v3.py @@ -8,8 +8,11 @@ from pydantic import BaseModel from nonebot_plugin_saa.registries import AllSupportedPlatformTarget from nonebot.compat import PYDANTIC_V2, ConfigDict, model_dump, type_validate_json, type_validate_python +from nonebot_bison.types import Tag +from nonebot_bison.types import Category +from nonebot_bison.types import Target as T_Target + from ..utils import NBESFParseErr -from ....types import Tag, Category from .base import NBESFBase, SubReceipt from ...db_model import Cookie as DBCookie from ...db_config import SubscribeDupException, config @@ -114,7 +117,7 @@ async def magic_cookie_gen(nbesf_data: SubGroup): new_cookie = DBCookie(**model_dump(cookie, exclude={"targets"})) cookie_id = await config.add_cookie(new_cookie) for target in cookie.targets: - await config.add_cookie_target(target.target, target.platform_name, cookie_id) + await config.add_cookie_target(T_Target(target.target), target.platform_name, cookie_id) except Exception as e: logger.error(f"!添加 Cookie 条目 {repr(cookie)} 失败: {repr(e)}") else: diff --git a/nonebot_bison/config/subs_io/subs_io.py b/nonebot_bison/config/subs_io/subs_io.py index 9b9ce48..c871096 100644 --- a/nonebot_bison/config/subs_io/subs_io.py +++ b/nonebot_bison/config/subs_io/subs_io.py @@ -65,7 +65,7 @@ async def subscribes_export(selector: Callable[[Select], Select]) -> v3.SubGroup target_payload = type_validate_python(v3.Target, cookie_target.target) cookie_target_dict[cookie_target.cookie].append(target_payload) - def cookie_transform(cookie: Cookie, targets: [Target]) -> v3.Cookie: + def cookie_transform(cookie: Cookie, targets: list[v3.Target]) -> v3.Cookie: cookie_dict = row2dict(cookie) cookie_dict["tags"] = cookie.tags cookie_dict["targets"] = targets diff --git a/nonebot_bison/config/subs_io/utils.py b/nonebot_bison/config/subs_io/utils.py index 1ba7558..21c3a5a 100644 --- a/nonebot_bison/config/subs_io/utils.py +++ b/nonebot_bison/config/subs_io/utils.py @@ -1,4 +1,6 @@ -from ..db_model import Model +from typing import Any + +from sqlalchemy.orm import DeclarativeBase class NBESFVerMatchErr(Exception): ... @@ -7,7 +9,7 @@ class NBESFVerMatchErr(Exception): ... class NBESFParseErr(Exception): ... -def row2dict(row: Model) -> dict: +def row2dict(row: DeclarativeBase) -> dict[str, Any]: d = {} for column in row.__table__.columns: d[column.name] = str(getattr(row, column.name)) diff --git a/nonebot_bison/script/cli.py b/nonebot_bison/script/cli.py index 790a7dd..f17e431 100644 --- a/nonebot_bison/script/cli.py +++ b/nonebot_bison/script/cli.py @@ -83,7 +83,7 @@ async def subs_export(path: Path, format: str): export_file = path / f"bison_subscribes_export_{int(time.time())}.{format}" logger.info("正在获取订阅信息...") - export_data: v2.SubGroup = await subscribes_export(lambda x: x) + export_data: v3.SubGroup = await subscribes_export(lambda x: x) with export_file.open("w", encoding="utf-8") as f: match format: diff --git a/nonebot_bison/sub_manager/add_cookie_target.py b/nonebot_bison/sub_manager/add_cookie_target.py index 8fd4780..919d0ca 100644 --- a/nonebot_bison/sub_manager/add_cookie_target.py +++ b/nonebot_bison/sub_manager/add_cookie_target.py @@ -26,8 +26,7 @@ def do_add_cookie_target(add_cookie_target_matcher: type[Matcher]): @add_cookie_target_matcher.got("target_idx", parameterless=[handle_cancel]) async def got_target_idx(state: T_State, target_idx: str = ArgPlainText()): try: - target_idx = int(target_idx) - state["target"] = state["sub_table"][target_idx] + state["target"] = state["sub_table"][int(target_idx)] state["site"] = platform_manager[state["target"]["platform_name"]].site except Exception: await add_cookie_target_matcher.reject("序号错误") @@ -57,8 +56,7 @@ def do_add_cookie_target(add_cookie_target_matcher: type[Matcher]): @add_cookie_target_matcher.got("cookie_idx", MessageTemplate("{_prompt}"), [handle_cancel]) async def got_cookie_idx(state: T_State, cookie_idx: str = ArgPlainText()): try: - cookie_idx = int(cookie_idx) - state["cookie"] = state["cookies"][cookie_idx - 1] + state["cookie"] = state["cookies"][int(cookie_idx) - 1] except Exception: await add_cookie_target_matcher.reject("序号错误") diff --git a/nonebot_bison/sub_manager/utils.py b/nonebot_bison/sub_manager/utils.py index 8126b9a..2418cfc 100644 --- a/nonebot_bison/sub_manager/utils.py +++ b/nonebot_bison/sub_manager/utils.py @@ -13,6 +13,7 @@ from nonebot_plugin_saa import PlatformTarget, extract_target from ..config import config from ..types import Category +from ..types import Target as T_Target from ..platform import platform_manager from ..plugin_config import plugin_config from ..utils.site import is_cookie_client_manager @@ -88,7 +89,7 @@ async def generate_sub_list_text( sub_list = [ sub for sub in sub_list - if is_cookie_client_manager(platform_manager.get(sub.target.platform_name).site.client_mgr) + if is_cookie_client_manager(platform_manager[sub.target.platform_name].site.client_mgr) ] if not sub_list: await matcher.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅") @@ -109,7 +110,7 @@ async def generate_sub_list_text( res += " {}".format(", ".join(sub.tags)) + "\n" if is_show_cookie: target_cookies = await config.get_cookie( - target=sub.target.target, site_name=platform.site.name, is_anonymous=False + target=T_Target(sub.target.target), site_name=platform.site.name, is_anonymous=False ) if target_cookies: res += " 关联的 Cookie:\n" @@ -126,6 +127,5 @@ async def only_allow_private( event: Event, matcher: type[Matcher], ): - # if not issubclass(PrivateMessageEvent, event.__class__): - if event.message_type != "private": + if not (hasattr(event, "message_type") and getattr(event, "message_type") == "private"): await matcher.finish("请在私聊中使用此命令") diff --git a/nonebot_bison/utils/site.py b/nonebot_bison/utils/site.py index 8dcd371..aecb00f 100644 --- a/nonebot_bison/utils/site.py +++ b/nonebot_bison/utils/site.py @@ -123,8 +123,8 @@ class CookieClientManager(ClientManager): async def _choose_cookie(self, target: Target | None) -> Cookie: """选择 cookie 的具体算法""" cookies = await config.get_cookie(self._site_name, target) - cookies = (cookie for cookie in cookies if cookie.last_usage + cookie.cd < datetime.now()) - cookie = min(cookies, key=lambda x: x.last_usage) + avaliable_cookies = (cookie for cookie in cookies if cookie.last_usage + cookie.cd < datetime.now()) + cookie = min(avaliable_cookies, key=lambda x: x.last_usage) return cookie async def get_client(self, target: Target | None) -> AsyncClient: @@ -183,8 +183,8 @@ class SiteMeta(type): cls._key = kwargs.get("key") elif not kwargs.get("abstract"): # this is the subclass - if hasattr(cls, "name"): - site_manager[cls.name] = cls + if "name" in namespace: + site_manager[namespace["name"]] = cls super().__init__(name, bases, namespace, **kwargs) diff --git a/tests/subs_io/test_cli.py b/tests/subs_io/test_cli.py index d74bac5..61fe271 100644 --- a/tests/subs_io/test_cli.py +++ b/tests/subs_io/test_cli.py @@ -78,7 +78,7 @@ async def test_subs_export(app: App, tmp_path: Path): cookie_name="test cookie", ) ) - await config.add_cookie_target("weibo_id", "weibo", cookie_id) + await config.add_cookie_target(TTarget("weibo_id"), "weibo", cookie_id) assert len(await config.list_subs_with_all_info()) == 3 diff --git a/tests/subs_io/test_subs_io.py b/tests/subs_io/test_subs_io.py index 700becf..5164672 100644 --- a/tests/subs_io/test_subs_io.py +++ b/tests/subs_io/test_subs_io.py @@ -16,7 +16,7 @@ async def test_subs_export(app: App, init_scheduler): await config.add_subscribe( TargetQQGroup(group_id=1232), - target=TTarget("weibo_id"), + target=TTarget(TTarget("weibo_id")), target_name="weibo_name", platform_name="weibo", cats=[], @@ -24,7 +24,7 @@ async def test_subs_export(app: App, init_scheduler): ) await config.add_subscribe( TargetQQGroup(group_id=2342), - target=TTarget("weibo_id"), + target=TTarget(TTarget("weibo_id")), target_name="weibo_name", platform_name="weibo", cats=[], @@ -45,7 +45,7 @@ async def test_subs_export(app: App, init_scheduler): cookie_name="test cookie", ) ) - await config.add_cookie_target("weibo_id", "weibo", cookie_id) + await config.add_cookie_target(TTarget("weibo_id"), "weibo", cookie_id) data = await config.list_subs_with_all_info() assert len(data) == 3