mirror of
				https://github.com/suyiiyii/nonebot-bison.git
				synced 2025-11-04 21:44:52 +08:00 
			
		
		
		
	🐛 修复 cookie 模块 type hint (#658)
This commit is contained in:
		
							parent
							
								
									23d945f8c7
								
							
						
					
					
						commit
						e496bf82e6
					
				@ -24,6 +24,7 @@ from ..utils.site import CookieClientManager, site_manager, is_cookie_client_man
 | 
				
			|||||||
from ..config import NoSuchUserException, NoSuchTargetException, NoSuchSubscribeException, config
 | 
					from ..config import NoSuchUserException, NoSuchTargetException, NoSuchSubscribeException, config
 | 
				
			||||||
from .types import (
 | 
					from .types import (
 | 
				
			||||||
    Cookie,
 | 
					    Cookie,
 | 
				
			||||||
 | 
					    Target,
 | 
				
			||||||
    TokenResp,
 | 
					    TokenResp,
 | 
				
			||||||
    GlobalConf,
 | 
					    GlobalConf,
 | 
				
			||||||
    SiteConfig,
 | 
					    SiteConfig,
 | 
				
			||||||
@ -213,7 +214,7 @@ async def update_weigth_config(platformName: str, target: str, weight_config: We
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@router.get("/cookie", dependencies=[Depends(check_is_superuser)])
 | 
					@router.get("/cookie", dependencies=[Depends(check_is_superuser)])
 | 
				
			||||||
async def get_cookie(site_name: str = None, target: str = None) -> list[Cookie]:
 | 
					async def get_cookie(site_name: str | None = None, target: str | None = None) -> list[Cookie]:
 | 
				
			||||||
    cookies_in_db = await config.get_cookie(site_name, is_anonymous=False)
 | 
					    cookies_in_db = await config.get_cookie(site_name, is_anonymous=False)
 | 
				
			||||||
    return [
 | 
					    return [
 | 
				
			||||||
        Cookie(
 | 
					        Cookie(
 | 
				
			||||||
@ -252,7 +253,12 @@ async def get_cookie_target(
 | 
				
			|||||||
    cookie_targets = await config.get_cookie_target()
 | 
					    cookie_targets = await config.get_cookie_target()
 | 
				
			||||||
    # TODO: filter in SQL
 | 
					    # TODO: filter in SQL
 | 
				
			||||||
    return [
 | 
					    return [
 | 
				
			||||||
        x
 | 
					        CookieTarget(
 | 
				
			||||||
 | 
					            target=Target(
 | 
				
			||||||
 | 
					                platform_name=x.target.platform_name, target_name=x.target.target_name, target=x.target.target
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            cookie_id=x.cookie.id,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        for x in cookie_targets
 | 
					        for x in cookie_targets
 | 
				
			||||||
        if (site_name is None or x.cookie.site_name == site_name)
 | 
					        if (site_name is None or x.cookie.site_name == site_name)
 | 
				
			||||||
        and (target is None or x.target.target == target)
 | 
					        and (target is None or x.target.target == target)
 | 
				
			||||||
@ -261,13 +267,13 @@ async def get_cookie_target(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@router.post("/cookie_target", dependencies=[Depends(check_is_superuser)])
 | 
					@router.post("/cookie_target", dependencies=[Depends(check_is_superuser)])
 | 
				
			||||||
async def add_cookie_target(platform_name: str, target: str, cookie_id: int) -> StatusResp:
 | 
					async def add_cookie_target(platform_name: str, target: T_Target, cookie_id: int) -> StatusResp:
 | 
				
			||||||
    await config.add_cookie_target(target, platform_name, cookie_id)
 | 
					    await config.add_cookie_target(target, platform_name, cookie_id)
 | 
				
			||||||
    return StatusResp(ok=True, msg="")
 | 
					    return StatusResp(ok=True, msg="")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@router.delete("/cookie_target", dependencies=[Depends(check_is_superuser)])
 | 
					@router.delete("/cookie_target", dependencies=[Depends(check_is_superuser)])
 | 
				
			||||||
async def del_cookie_target(platform_name: str, target: str, cookie_id: int) -> StatusResp:
 | 
					async def del_cookie_target(platform_name: str, target: T_Target, cookie_id: int) -> StatusResp:
 | 
				
			||||||
    await config.delete_cookie_target(target, platform_name, cookie_id)
 | 
					    await config.delete_cookie_target(target, platform_name, cookie_id)
 | 
				
			||||||
    return StatusResp(ok=True, msg="")
 | 
					    return StatusResp(ok=True, msg="")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -288,6 +288,8 @@ class DBConfig:
 | 
				
			|||||||
    async def get_cookie_by_id(self, cookie_id: int) -> Cookie:
 | 
					    async def get_cookie_by_id(self, cookie_id: int) -> Cookie:
 | 
				
			||||||
        async with create_session() as sess:
 | 
					        async with create_session() as sess:
 | 
				
			||||||
            cookie = await sess.scalar(select(Cookie).where(Cookie.id == cookie_id))
 | 
					            cookie = await sess.scalar(select(Cookie).where(Cookie.id == cookie_id))
 | 
				
			||||||
 | 
					            if not cookie:
 | 
				
			||||||
 | 
					                raise NoSuchTargetException(f"cookie {cookie_id} not found")
 | 
				
			||||||
            return cookie
 | 
					            return cookie
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def add_cookie(self, cookie: Cookie) -> int:
 | 
					    async def add_cookie(self, cookie: Cookie) -> int:
 | 
				
			||||||
@ -317,6 +319,8 @@ class DBConfig:
 | 
				
			|||||||
                .outerjoin(CookieTarget)
 | 
					                .outerjoin(CookieTarget)
 | 
				
			||||||
                .options(selectinload(Cookie.targets))
 | 
					                .options(selectinload(Cookie.targets))
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					            if not cookie:
 | 
				
			||||||
 | 
					                raise NoSuchTargetException(f"cookie {cookie_id} not found")
 | 
				
			||||||
            if len(cookie.targets) > 0:
 | 
					            if len(cookie.targets) > 0:
 | 
				
			||||||
                raise Exception(f"cookie {cookie.id} in use")
 | 
					                raise Exception(f"cookie {cookie.id} in use")
 | 
				
			||||||
            await sess.execute(delete(Cookie).where(Cookie.id == cookie_id))
 | 
					            await sess.execute(delete(Cookie).where(Cookie.id == cookie_id))
 | 
				
			||||||
 | 
				
			|||||||
@ -8,8 +8,11 @@ from pydantic import BaseModel
 | 
				
			|||||||
from nonebot_plugin_saa.registries import AllSupportedPlatformTarget
 | 
					from nonebot_plugin_saa.registries import AllSupportedPlatformTarget
 | 
				
			||||||
from nonebot.compat import PYDANTIC_V2, ConfigDict, model_dump, type_validate_json, type_validate_python
 | 
					from nonebot.compat import PYDANTIC_V2, ConfigDict, model_dump, type_validate_json, type_validate_python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from nonebot_bison.types import Tag
 | 
				
			||||||
 | 
					from nonebot_bison.types import Category
 | 
				
			||||||
 | 
					from nonebot_bison.types import Target as T_Target
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..utils import NBESFParseErr
 | 
					from ..utils import NBESFParseErr
 | 
				
			||||||
from ....types import Tag, Category
 | 
					 | 
				
			||||||
from .base import NBESFBase, SubReceipt
 | 
					from .base import NBESFBase, SubReceipt
 | 
				
			||||||
from ...db_model import Cookie as DBCookie
 | 
					from ...db_model import Cookie as DBCookie
 | 
				
			||||||
from ...db_config import SubscribeDupException, config
 | 
					from ...db_config import SubscribeDupException, config
 | 
				
			||||||
@ -114,7 +117,7 @@ async def magic_cookie_gen(nbesf_data: SubGroup):
 | 
				
			|||||||
            new_cookie = DBCookie(**model_dump(cookie, exclude={"targets"}))
 | 
					            new_cookie = DBCookie(**model_dump(cookie, exclude={"targets"}))
 | 
				
			||||||
            cookie_id = await config.add_cookie(new_cookie)
 | 
					            cookie_id = await config.add_cookie(new_cookie)
 | 
				
			||||||
            for target in cookie.targets:
 | 
					            for target in cookie.targets:
 | 
				
			||||||
                await config.add_cookie_target(target.target, target.platform_name, cookie_id)
 | 
					                await config.add_cookie_target(T_Target(target.target), target.platform_name, cookie_id)
 | 
				
			||||||
        except Exception as e:
 | 
					        except Exception as e:
 | 
				
			||||||
            logger.error(f"!添加 Cookie 条目 {repr(cookie)} 失败: {repr(e)}")
 | 
					            logger.error(f"!添加 Cookie 条目 {repr(cookie)} 失败: {repr(e)}")
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
 | 
				
			|||||||
@ -65,7 +65,7 @@ async def subscribes_export(selector: Callable[[Select], Select]) -> v3.SubGroup
 | 
				
			|||||||
        target_payload = type_validate_python(v3.Target, cookie_target.target)
 | 
					        target_payload = type_validate_python(v3.Target, cookie_target.target)
 | 
				
			||||||
        cookie_target_dict[cookie_target.cookie].append(target_payload)
 | 
					        cookie_target_dict[cookie_target.cookie].append(target_payload)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def cookie_transform(cookie: Cookie, targets: [Target]) -> v3.Cookie:
 | 
					    def cookie_transform(cookie: Cookie, targets: list[v3.Target]) -> v3.Cookie:
 | 
				
			||||||
        cookie_dict = row2dict(cookie)
 | 
					        cookie_dict = row2dict(cookie)
 | 
				
			||||||
        cookie_dict["tags"] = cookie.tags
 | 
					        cookie_dict["tags"] = cookie.tags
 | 
				
			||||||
        cookie_dict["targets"] = targets
 | 
					        cookie_dict["targets"] = targets
 | 
				
			||||||
 | 
				
			|||||||
@ -1,4 +1,6 @@
 | 
				
			|||||||
from ..db_model import Model
 | 
					from typing import Any
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from sqlalchemy.orm import DeclarativeBase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class NBESFVerMatchErr(Exception): ...
 | 
					class NBESFVerMatchErr(Exception): ...
 | 
				
			||||||
@ -7,7 +9,7 @@ class NBESFVerMatchErr(Exception): ...
 | 
				
			|||||||
class NBESFParseErr(Exception): ...
 | 
					class NBESFParseErr(Exception): ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def row2dict(row: Model) -> dict:
 | 
					def row2dict(row: DeclarativeBase) -> dict[str, Any]:
 | 
				
			||||||
    d = {}
 | 
					    d = {}
 | 
				
			||||||
    for column in row.__table__.columns:
 | 
					    for column in row.__table__.columns:
 | 
				
			||||||
        d[column.name] = str(getattr(row, column.name))
 | 
					        d[column.name] = str(getattr(row, column.name))
 | 
				
			||||||
 | 
				
			|||||||
@ -83,7 +83,7 @@ async def subs_export(path: Path, format: str):
 | 
				
			|||||||
    export_file = path / f"bison_subscribes_export_{int(time.time())}.{format}"
 | 
					    export_file = path / f"bison_subscribes_export_{int(time.time())}.{format}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    logger.info("正在获取订阅信息...")
 | 
					    logger.info("正在获取订阅信息...")
 | 
				
			||||||
    export_data: v2.SubGroup = await subscribes_export(lambda x: x)
 | 
					    export_data: v3.SubGroup = await subscribes_export(lambda x: x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    with export_file.open("w", encoding="utf-8") as f:
 | 
					    with export_file.open("w", encoding="utf-8") as f:
 | 
				
			||||||
        match format:
 | 
					        match format:
 | 
				
			||||||
 | 
				
			|||||||
@ -26,8 +26,7 @@ def do_add_cookie_target(add_cookie_target_matcher: type[Matcher]):
 | 
				
			|||||||
    @add_cookie_target_matcher.got("target_idx", parameterless=[handle_cancel])
 | 
					    @add_cookie_target_matcher.got("target_idx", parameterless=[handle_cancel])
 | 
				
			||||||
    async def got_target_idx(state: T_State, target_idx: str = ArgPlainText()):
 | 
					    async def got_target_idx(state: T_State, target_idx: str = ArgPlainText()):
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            target_idx = int(target_idx)
 | 
					            state["target"] = state["sub_table"][int(target_idx)]
 | 
				
			||||||
            state["target"] = state["sub_table"][target_idx]
 | 
					 | 
				
			||||||
            state["site"] = platform_manager[state["target"]["platform_name"]].site
 | 
					            state["site"] = platform_manager[state["target"]["platform_name"]].site
 | 
				
			||||||
        except Exception:
 | 
					        except Exception:
 | 
				
			||||||
            await add_cookie_target_matcher.reject("序号错误")
 | 
					            await add_cookie_target_matcher.reject("序号错误")
 | 
				
			||||||
@ -57,8 +56,7 @@ def do_add_cookie_target(add_cookie_target_matcher: type[Matcher]):
 | 
				
			|||||||
    @add_cookie_target_matcher.got("cookie_idx", MessageTemplate("{_prompt}"), [handle_cancel])
 | 
					    @add_cookie_target_matcher.got("cookie_idx", MessageTemplate("{_prompt}"), [handle_cancel])
 | 
				
			||||||
    async def got_cookie_idx(state: T_State, cookie_idx: str = ArgPlainText()):
 | 
					    async def got_cookie_idx(state: T_State, cookie_idx: str = ArgPlainText()):
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            cookie_idx = int(cookie_idx)
 | 
					            state["cookie"] = state["cookies"][int(cookie_idx) - 1]
 | 
				
			||||||
            state["cookie"] = state["cookies"][cookie_idx - 1]
 | 
					 | 
				
			||||||
        except Exception:
 | 
					        except Exception:
 | 
				
			||||||
            await add_cookie_target_matcher.reject("序号错误")
 | 
					            await add_cookie_target_matcher.reject("序号错误")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -13,6 +13,7 @@ from nonebot_plugin_saa import PlatformTarget, extract_target
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from ..config import config
 | 
					from ..config import config
 | 
				
			||||||
from ..types import Category
 | 
					from ..types import Category
 | 
				
			||||||
 | 
					from ..types import Target as T_Target
 | 
				
			||||||
from ..platform import platform_manager
 | 
					from ..platform import platform_manager
 | 
				
			||||||
from ..plugin_config import plugin_config
 | 
					from ..plugin_config import plugin_config
 | 
				
			||||||
from ..utils.site import is_cookie_client_manager
 | 
					from ..utils.site import is_cookie_client_manager
 | 
				
			||||||
@ -88,7 +89,7 @@ async def generate_sub_list_text(
 | 
				
			|||||||
        sub_list = [
 | 
					        sub_list = [
 | 
				
			||||||
            sub
 | 
					            sub
 | 
				
			||||||
            for sub in sub_list
 | 
					            for sub in sub_list
 | 
				
			||||||
            if is_cookie_client_manager(platform_manager.get(sub.target.platform_name).site.client_mgr)
 | 
					            if is_cookie_client_manager(platform_manager[sub.target.platform_name].site.client_mgr)
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
    if not sub_list:
 | 
					    if not sub_list:
 | 
				
			||||||
        await matcher.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅")
 | 
					        await matcher.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅")
 | 
				
			||||||
@ -109,7 +110,7 @@ async def generate_sub_list_text(
 | 
				
			|||||||
                    res += " {}".format(", ".join(sub.tags)) + "\n"
 | 
					                    res += " {}".format(", ".join(sub.tags)) + "\n"
 | 
				
			||||||
            if is_show_cookie:
 | 
					            if is_show_cookie:
 | 
				
			||||||
                target_cookies = await config.get_cookie(
 | 
					                target_cookies = await config.get_cookie(
 | 
				
			||||||
                    target=sub.target.target, site_name=platform.site.name, is_anonymous=False
 | 
					                    target=T_Target(sub.target.target), site_name=platform.site.name, is_anonymous=False
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                if target_cookies:
 | 
					                if target_cookies:
 | 
				
			||||||
                    res += "  关联的 Cookie:\n"
 | 
					                    res += "  关联的 Cookie:\n"
 | 
				
			||||||
@ -126,6 +127,5 @@ async def only_allow_private(
 | 
				
			|||||||
    event: Event,
 | 
					    event: Event,
 | 
				
			||||||
    matcher: type[Matcher],
 | 
					    matcher: type[Matcher],
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    # if not issubclass(PrivateMessageEvent, event.__class__):
 | 
					    if not (hasattr(event, "message_type") and getattr(event, "message_type") == "private"):
 | 
				
			||||||
    if event.message_type != "private":
 | 
					 | 
				
			||||||
        await matcher.finish("请在私聊中使用此命令")
 | 
					        await matcher.finish("请在私聊中使用此命令")
 | 
				
			||||||
 | 
				
			|||||||
@ -123,8 +123,8 @@ class CookieClientManager(ClientManager):
 | 
				
			|||||||
    async def _choose_cookie(self, target: Target | None) -> Cookie:
 | 
					    async def _choose_cookie(self, target: Target | None) -> Cookie:
 | 
				
			||||||
        """选择 cookie 的具体算法"""
 | 
					        """选择 cookie 的具体算法"""
 | 
				
			||||||
        cookies = await config.get_cookie(self._site_name, target)
 | 
					        cookies = await config.get_cookie(self._site_name, target)
 | 
				
			||||||
        cookies = (cookie for cookie in cookies if cookie.last_usage + cookie.cd < datetime.now())
 | 
					        avaliable_cookies = (cookie for cookie in cookies if cookie.last_usage + cookie.cd < datetime.now())
 | 
				
			||||||
        cookie = min(cookies, key=lambda x: x.last_usage)
 | 
					        cookie = min(avaliable_cookies, key=lambda x: x.last_usage)
 | 
				
			||||||
        return cookie
 | 
					        return cookie
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def get_client(self, target: Target | None) -> AsyncClient:
 | 
					    async def get_client(self, target: Target | None) -> AsyncClient:
 | 
				
			||||||
@ -183,8 +183,8 @@ class SiteMeta(type):
 | 
				
			|||||||
            cls._key = kwargs.get("key")
 | 
					            cls._key = kwargs.get("key")
 | 
				
			||||||
        elif not kwargs.get("abstract"):
 | 
					        elif not kwargs.get("abstract"):
 | 
				
			||||||
            # this is the subclass
 | 
					            # this is the subclass
 | 
				
			||||||
            if hasattr(cls, "name"):
 | 
					            if "name" in namespace:
 | 
				
			||||||
                site_manager[cls.name] = cls
 | 
					                site_manager[namespace["name"]] = cls
 | 
				
			||||||
        super().__init__(name, bases, namespace, **kwargs)
 | 
					        super().__init__(name, bases, namespace, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -78,7 +78,7 @@ async def test_subs_export(app: App, tmp_path: Path):
 | 
				
			|||||||
            cookie_name="test cookie",
 | 
					            cookie_name="test cookie",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    await config.add_cookie_target("weibo_id", "weibo", cookie_id)
 | 
					    await config.add_cookie_target(TTarget("weibo_id"), "weibo", cookie_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    assert len(await config.list_subs_with_all_info()) == 3
 | 
					    assert len(await config.list_subs_with_all_info()) == 3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -16,7 +16,7 @@ async def test_subs_export(app: App, init_scheduler):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    await config.add_subscribe(
 | 
					    await config.add_subscribe(
 | 
				
			||||||
        TargetQQGroup(group_id=1232),
 | 
					        TargetQQGroup(group_id=1232),
 | 
				
			||||||
        target=TTarget("weibo_id"),
 | 
					        target=TTarget(TTarget("weibo_id")),
 | 
				
			||||||
        target_name="weibo_name",
 | 
					        target_name="weibo_name",
 | 
				
			||||||
        platform_name="weibo",
 | 
					        platform_name="weibo",
 | 
				
			||||||
        cats=[],
 | 
					        cats=[],
 | 
				
			||||||
@ -24,7 +24,7 @@ async def test_subs_export(app: App, init_scheduler):
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
    await config.add_subscribe(
 | 
					    await config.add_subscribe(
 | 
				
			||||||
        TargetQQGroup(group_id=2342),
 | 
					        TargetQQGroup(group_id=2342),
 | 
				
			||||||
        target=TTarget("weibo_id"),
 | 
					        target=TTarget(TTarget("weibo_id")),
 | 
				
			||||||
        target_name="weibo_name",
 | 
					        target_name="weibo_name",
 | 
				
			||||||
        platform_name="weibo",
 | 
					        platform_name="weibo",
 | 
				
			||||||
        cats=[],
 | 
					        cats=[],
 | 
				
			||||||
@ -45,7 +45,7 @@ async def test_subs_export(app: App, init_scheduler):
 | 
				
			|||||||
            cookie_name="test cookie",
 | 
					            cookie_name="test cookie",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    await config.add_cookie_target("weibo_id", "weibo", cookie_id)
 | 
					    await config.add_cookie_target(TTarget("weibo_id"), "weibo", cookie_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    data = await config.list_subs_with_all_info()
 | 
					    data = await config.list_subs_with_all_info()
 | 
				
			||||||
    assert len(data) == 3
 | 
					    assert len(data) == 3
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user