diff --git a/nonebot_bison/config/db_config.py b/nonebot_bison/config/db_config.py index 92da96c..2f85775 100644 --- a/nonebot_bison/config/db_config.py +++ b/nonebot_bison/config/db_config.py @@ -259,36 +259,32 @@ class DBConfig: ) return res - async def get_cookie(self, site_name: str = None, target: T_Target = None) -> Sequence[Cookie]: - """根据 site_name 和 target 获取 cookie,不会返回匿名cookie""" + async def get_cookie( + self, + site_name: str | None = None, + target: T_Target | None = None, + is_universal: bool | None = None, + is_anonymous: bool | None = None, + ) -> Sequence[Cookie]: + """获取满足传入条件的所有 cookie""" async with create_session() as sess: - query = select(Cookie).distinct().where(Cookie.is_universal == False) # noqa: E712 + query = select(Cookie).distinct() + if is_universal is not None: + query = query.where(Cookie.is_universal == is_universal) + if is_anonymous is not None: + query = query.where(Cookie.is_anonymous == is_anonymous) if site_name: query = query.where(Cookie.site_name == site_name) query = query.outerjoin(CookieTarget).options(selectinload(Cookie.targets)) res = (await sess.scalars(query)).all() if target: + # 如果指定了 target,过滤掉不满足要求的cookie query = select(CookieTarget.cookie_id).join(Target).where(Target.target == target) ids = set((await sess.scalars(query)).all()) - res = [cookie for cookie in res if cookie.id in ids] + # 如果指定了 target 且未指定 is_universal,则添加返回 universal cookie + res = [cookie for cookie in res if cookie.id in ids or cookie.is_universal] return res - async def get_unviersal_cookie(self, site_name: str = None) -> Sequence[Cookie]: - async with create_session() as sess: - query = select(Cookie).distinct().where(Cookie.is_universal == True) # noqa: E712 - if site_name: - query = query.where(Cookie.site_name == site_name) - res = (await sess.scalars(query)).all() - return res - - async def add_cookie_with_content(self, site_name: str, content: str) -> int: - async with create_session() as sess: - cookie = Cookie(site_name=site_name, content=content) - sess.add(cookie) - await sess.commit() - await sess.refresh(cookie) - return cookie.id - async def add_cookie(self, cookie: Cookie) -> int: async with create_session() as sess: sess.add(cookie) @@ -320,21 +316,6 @@ class DBConfig: await sess.execute(delete(Cookie).where(Cookie.id == cookie_id)) await sess.commit() - async def get_cookie_by_target(self, target: T_Target, site_name: str) -> Sequence[Cookie]: - async with create_session() as sess: - query = ( - select(Cookie) - .join(CookieTarget) - .join(Target) - .where(Target.site_name == site_name, Target.target == target) - ) - return (await sess.scalars(query)).all() - - async def get_universal_cookie(self, site_name: str) -> Sequence[Cookie]: - async with create_session() as sess: - query = select(Cookie).where(Cookie.site_name == site_name).where(Cookie.is_universal == True) # noqa: E712 - return (await sess.scalars(query)).all() - async def add_cookie_target(self, target: T_Target, platform_name: str, cookie_id: int): """通过 cookie_id 可以唯一确定一个 Cookie,通过 target 和 platform_name 可以唯一确定一个 Target""" async with create_session() as sess: diff --git a/nonebot_bison/sub_manager/add_cookie_target.py b/nonebot_bison/sub_manager/add_cookie_target.py index 9d31593..73d72e2 100644 --- a/nonebot_bison/sub_manager/add_cookie_target.py +++ b/nonebot_bison/sub_manager/add_cookie_target.py @@ -36,10 +36,12 @@ def do_add_cookie_target(add_cookie_target_matcher: type[Matcher]): @add_cookie_target_matcher.handle() async def init_promote_cookie(state: T_State): - cookies = await config.get_cookie(site_name=state["site"].name) + # 获取 site 的所有用户 cookie,再排除掉已经关联的 cookie,剩下的就是可以关联的 cookie + cookies = await config.get_cookie(site_name=state["site"].name, is_anonymous=False) associated_cookies = await config.get_cookie( target=state["target"]["target"], site_name=state["site"].name, + is_anonymous=False, ) associated_cookie_ids = {cookie.id for cookie in associated_cookies} cookies = [cookie for cookie in cookies if cookie.id not in associated_cookie_ids] diff --git a/nonebot_bison/sub_manager/del_cookie.py b/nonebot_bison/sub_manager/del_cookie.py index 6b3056f..e343d4f 100644 --- a/nonebot_bison/sub_manager/del_cookie.py +++ b/nonebot_bison/sub_manager/del_cookie.py @@ -14,7 +14,7 @@ def do_del_cookie(del_cookie: type[Matcher]): @del_cookie.handle() async def send_list(state: T_State): - cookies = await config.get_cookie() + cookies = await config.get_cookie(is_anonymous=False) if not cookies: await del_cookie.finish("暂无已添加 Cookie\n请使用“添加cookie”命令添加") res = "已添加的 Cookie 为:\n" diff --git a/nonebot_bison/sub_manager/utils.py b/nonebot_bison/sub_manager/utils.py index ad6cb5e..a74d8ac 100644 --- a/nonebot_bison/sub_manager/utils.py +++ b/nonebot_bison/sub_manager/utils.py @@ -108,7 +108,9 @@ async def generate_sub_list_text( if sub.tags: res += " {}".format(", ".join(sub.tags)) + "\n" if is_show_cookie: - target_cookies = await config.get_cookie(target=sub.target.target, site_name=platform.site.name) + target_cookies = await config.get_cookie( + target=sub.target.target, site_name=platform.site.name, is_anonymous=False + ) if target_cookies: res += " 关联的 Cookie:\n" for cookie in target_cookies: diff --git a/nonebot_bison/utils/site.py b/nonebot_bison/utils/site.py index 2da4422..b934cf4 100644 --- a/nonebot_bison/utils/site.py +++ b/nonebot_bison/utils/site.py @@ -49,7 +49,7 @@ class CookieClientManager(ClientManager): @classmethod async def refresh_anonymous_cookie(cls): """移除已有的匿名cookie,添加一个新的匿名cookie""" - anonymous_cookies = await config.get_unviersal_cookie(cls._site_name) + anonymous_cookies = await config.get_cookie(cls._site_name, is_anonymous=True) anonymous_cookie = Cookie(site_name=cls._site_name, content="{}", is_universal=True, is_anonymous=True) for cookie in anonymous_cookies: if not cookie.is_anonymous: @@ -99,11 +99,9 @@ class CookieClientManager(ClientManager): return _response_hook - async def _choose_cookie(self, target: Target) -> Cookie: + async def _choose_cookie(self, target: Target | None) -> Cookie: """选择 cookie 的具体算法""" - cookies = await config.get_universal_cookie(self._site_name) - if target: - cookies += await config.get_cookie(self._site_name, target) + cookies = await config.get_cookie(self._site_name, target) cookies = (cookie for cookie in cookies if cookie.last_usage + cookie.cd < datetime.now()) cookie = min(cookies, key=lambda x: x.last_usage) return cookie diff --git a/tests/config/test_cookie.py b/tests/config/test_cookie.py index b13fa3f..7db0b47 100644 --- a/tests/config/test_cookie.py +++ b/tests/config/test_cookie.py @@ -1,107 +1,74 @@ -import datetime +import json +from typing import cast from nonebug import App -async def test_get_platform_target(app: App, init_scheduler): +async def test_cookie(app: App, init_scheduler): from nonebot_plugin_saa import TargetQQGroup + from nonebot_bison.platform import site_manager from nonebot_bison.config.db_config import config from nonebot_bison.types import Target as T_Target + from nonebot_bison.utils.site import CookieClientManager + target = T_Target("weibo_id") + platform_name = "weibo" await config.add_subscribe( TargetQQGroup(group_id=123), - target=T_Target("weibo_id"), + target=target, target_name="weibo_name", - platform_name="weibo", + platform_name=platform_name, cats=[], tags=[], ) - # await config.add_cookie(TargetQQGroup(group_id=123), "weibo", "cookie") - # cookies = await config.get_cookie_by_user(TargetQQGroup(group_id=123)) - # - # res = await config.get_platform_target("weibo") - # assert len(res) == 2 - # await config.del_subscribe(TargetQQGroup(group_id=123), T_Target("weibo_id1"), "weibo") - # res = await config.get_platform_target("weibo") - # assert len(res) == 2 - # await config.del_subscribe(TargetQQGroup(group_id=123), T_Target("weibo_id"), "weibo") - # res = await config.get_platform_target("weibo") - # assert len(res) == 1 - # - # async with AsyncSession(get_engine()) as sess: - # res = await sess.scalars(select(Target).where(Target.platform_name == "weibo")) - # assert len(res.all()) == 2 - # await config.get_cookie_by_user(TargetQQGroup(group_id=123)) + site = site_manager["weibo.com"] + client_mgr = cast(CookieClientManager, site.client_mgr) + await client_mgr.refresh_anonymous_cookie() # 刷新匿名cookie -async def test_cookie_by_user(app: App, init_scheduler): - from nonebot_plugin_saa import TargetQQGroup + cookies = await config.get_cookie(site_name=site.name) + assert len(cookies) == 1 - from nonebot_bison.config.db_config import config - from nonebot_bison.types import Target as T_Target + await client_mgr.add_user_cookie(json.dumps({"test_cookie": "1"})) + await client_mgr.add_user_cookie(json.dumps({"test_cookie": "2"})) + cookies = await config.get_cookie(site_name=site.name) + assert len(cookies) == 3 + + cookies = await config.get_cookie(site_name=site.name, is_anonymous=False) + assert len(cookies) == 2 + + await config.add_cookie_target(target, platform_name, cookies[0].id) + await config.add_cookie_target(target, platform_name, cookies[1].id) + + cookies = await config.get_cookie(site_name=site.name, target=target) + assert len(cookies) == 3 + + cookies = await config.get_cookie(site_name=site.name, target=target, is_anonymous=False) + assert len(cookies) == 2 + + cookies = await config.get_cookie(site_name=site.name, target=target, is_universal=False) + assert len(cookies) == 2 + + # 测试不同的target + target2 = T_Target("weibo_id2") await config.add_subscribe( TargetQQGroup(group_id=123), - target=T_Target("weibo_id"), - target_name="weibo_name", - platform_name="weibo", + target=target2, + target_name="weibo_name2", + platform_name=platform_name, cats=[], tags=[], ) + await client_mgr.add_user_cookie(json.dumps({"test_cookie": "3"})) + cookies = await config.get_cookie(site_name=site.name, is_anonymous=False) - await config.add_cookie_with_content(TargetQQGroup(group_id=123), "weibo", "cookie") + await config.add_cookie_target(target2, platform_name, cookies[0].id) + await config.add_cookie_target(target2, platform_name, cookies[2].id) - cookies = await config.get_cookie(TargetQQGroup(group_id=123)) - cookie = cookies[0] - assert len(cookies) == 1 - assert cookie.content == "cookie" - assert cookie.platform_name == "weibo" - cookie.last_usage = 0 - assert cookie.status == "" - assert cookie.tags == {} - cookie.content = "cookie1" - cookie.last_usage = datetime.datetime(2024, 8, 22, 0, 0, 0) - cookie.status = "status1" - cookie.tags = {"tag1": "value1"} - await config.update_cookie(cookie) - cookies = await config.get_cookie(TargetQQGroup(group_id=123)) + cookies = await config.get_cookie(site_name=site.name, target=target2) + assert len(cookies) == 3 - assert len(cookies) == 1 - assert cookies[0].content == cookie.content - assert cookies[0].last_usage == cookie.last_usage - assert cookies[0].status == cookie.status - assert cookies[0].tags == cookie.tags - - await config.delete_cookie_by_id(cookies[0].id) - cookies = await config.get_cookie(TargetQQGroup(group_id=123)) - assert len(cookies) == 0 - - -async def test_cookie_target_by_target(app: App, init_scheduler): - from nonebot_plugin_saa import TargetQQGroup - - from nonebot_bison.config.db_config import config - from nonebot_bison.types import Target as T_Target - - await config.add_subscribe( - TargetQQGroup(group_id=123), - target=T_Target("weibo_id"), - target_name="weibo_name", - platform_name="weibo", - cats=[], - tags=[], - ) - - id = await config.add_cookie_with_content(TargetQQGroup(group_id=123), "weibo", "cookie") - - await config.add_cookie_target(T_Target("weibo_id"), "weibo", id) - - cookies = await config.get_cookie_by_target(T_Target("weibo_id"), "weibo") - assert len(cookies) == 1 - assert cookies[0].content == "cookie" - assert cookies[0].platform_name == "weibo" - - await config.delete_cookie_target(T_Target("weibo_id"), "weibo", id) - cookies = await config.get_cookie_by_target(T_Target("weibo_id"), "weibo") - assert len(cookies) == 0 + cookies = await config.get_cookie(site_name=site.name, target=target2, is_anonymous=False) + assert len(cookies) == 2