mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-07-16 13:22:59 +08:00
♻️ 重构 get_cookie 方法
This commit is contained in:
parent
4b8d6a9379
commit
4791fb69e0
nonebot_bison
tests/config
@ -259,36 +259,32 @@ class DBConfig:
|
||||
)
|
||||
return res
|
||||
|
||||
async def get_cookie(self, site_name: str = None, target: T_Target = None) -> Sequence[Cookie]:
|
||||
"""根据 site_name 和 target 获取 cookie,不会返回匿名cookie"""
|
||||
async def get_cookie(
|
||||
self,
|
||||
site_name: str | None = None,
|
||||
target: T_Target | None = None,
|
||||
is_universal: bool | None = None,
|
||||
is_anonymous: bool | None = None,
|
||||
) -> Sequence[Cookie]:
|
||||
"""获取满足传入条件的所有 cookie"""
|
||||
async with create_session() as sess:
|
||||
query = select(Cookie).distinct().where(Cookie.is_universal == False) # noqa: E712
|
||||
query = select(Cookie).distinct()
|
||||
if is_universal is not None:
|
||||
query = query.where(Cookie.is_universal == is_universal)
|
||||
if is_anonymous is not None:
|
||||
query = query.where(Cookie.is_anonymous == is_anonymous)
|
||||
if site_name:
|
||||
query = query.where(Cookie.site_name == site_name)
|
||||
query = query.outerjoin(CookieTarget).options(selectinload(Cookie.targets))
|
||||
res = (await sess.scalars(query)).all()
|
||||
if target:
|
||||
# 如果指定了 target,过滤掉不满足要求的cookie
|
||||
query = select(CookieTarget.cookie_id).join(Target).where(Target.target == target)
|
||||
ids = set((await sess.scalars(query)).all())
|
||||
res = [cookie for cookie in res if cookie.id in ids]
|
||||
# 如果指定了 target 且未指定 is_universal,则添加返回 universal cookie
|
||||
res = [cookie for cookie in res if cookie.id in ids or cookie.is_universal]
|
||||
return res
|
||||
|
||||
async def get_unviersal_cookie(self, site_name: str = None) -> Sequence[Cookie]:
|
||||
async with create_session() as sess:
|
||||
query = select(Cookie).distinct().where(Cookie.is_universal == True) # noqa: E712
|
||||
if site_name:
|
||||
query = query.where(Cookie.site_name == site_name)
|
||||
res = (await sess.scalars(query)).all()
|
||||
return res
|
||||
|
||||
async def add_cookie_with_content(self, site_name: str, content: str) -> int:
|
||||
async with create_session() as sess:
|
||||
cookie = Cookie(site_name=site_name, content=content)
|
||||
sess.add(cookie)
|
||||
await sess.commit()
|
||||
await sess.refresh(cookie)
|
||||
return cookie.id
|
||||
|
||||
async def add_cookie(self, cookie: Cookie) -> int:
|
||||
async with create_session() as sess:
|
||||
sess.add(cookie)
|
||||
@ -320,21 +316,6 @@ class DBConfig:
|
||||
await sess.execute(delete(Cookie).where(Cookie.id == cookie_id))
|
||||
await sess.commit()
|
||||
|
||||
async def get_cookie_by_target(self, target: T_Target, site_name: str) -> Sequence[Cookie]:
|
||||
async with create_session() as sess:
|
||||
query = (
|
||||
select(Cookie)
|
||||
.join(CookieTarget)
|
||||
.join(Target)
|
||||
.where(Target.site_name == site_name, Target.target == target)
|
||||
)
|
||||
return (await sess.scalars(query)).all()
|
||||
|
||||
async def get_universal_cookie(self, site_name: str) -> Sequence[Cookie]:
|
||||
async with create_session() as sess:
|
||||
query = select(Cookie).where(Cookie.site_name == site_name).where(Cookie.is_universal == True) # noqa: E712
|
||||
return (await sess.scalars(query)).all()
|
||||
|
||||
async def add_cookie_target(self, target: T_Target, platform_name: str, cookie_id: int):
|
||||
"""通过 cookie_id 可以唯一确定一个 Cookie,通过 target 和 platform_name 可以唯一确定一个 Target"""
|
||||
async with create_session() as sess:
|
||||
|
@ -36,10 +36,12 @@ def do_add_cookie_target(add_cookie_target_matcher: type[Matcher]):
|
||||
@add_cookie_target_matcher.handle()
|
||||
async def init_promote_cookie(state: T_State):
|
||||
|
||||
cookies = await config.get_cookie(site_name=state["site"].name)
|
||||
# 获取 site 的所有用户 cookie,再排除掉已经关联的 cookie,剩下的就是可以关联的 cookie
|
||||
cookies = await config.get_cookie(site_name=state["site"].name, is_anonymous=False)
|
||||
associated_cookies = await config.get_cookie(
|
||||
target=state["target"]["target"],
|
||||
site_name=state["site"].name,
|
||||
is_anonymous=False,
|
||||
)
|
||||
associated_cookie_ids = {cookie.id for cookie in associated_cookies}
|
||||
cookies = [cookie for cookie in cookies if cookie.id not in associated_cookie_ids]
|
||||
|
@ -14,7 +14,7 @@ def do_del_cookie(del_cookie: type[Matcher]):
|
||||
|
||||
@del_cookie.handle()
|
||||
async def send_list(state: T_State):
|
||||
cookies = await config.get_cookie()
|
||||
cookies = await config.get_cookie(is_anonymous=False)
|
||||
if not cookies:
|
||||
await del_cookie.finish("暂无已添加 Cookie\n请使用“添加cookie”命令添加")
|
||||
res = "已添加的 Cookie 为:\n"
|
||||
|
@ -108,7 +108,9 @@ async def generate_sub_list_text(
|
||||
if sub.tags:
|
||||
res += " {}".format(", ".join(sub.tags)) + "\n"
|
||||
if is_show_cookie:
|
||||
target_cookies = await config.get_cookie(target=sub.target.target, site_name=platform.site.name)
|
||||
target_cookies = await config.get_cookie(
|
||||
target=sub.target.target, site_name=platform.site.name, is_anonymous=False
|
||||
)
|
||||
if target_cookies:
|
||||
res += " 关联的 Cookie:\n"
|
||||
for cookie in target_cookies:
|
||||
|
@ -49,7 +49,7 @@ class CookieClientManager(ClientManager):
|
||||
@classmethod
|
||||
async def refresh_anonymous_cookie(cls):
|
||||
"""移除已有的匿名cookie,添加一个新的匿名cookie"""
|
||||
anonymous_cookies = await config.get_unviersal_cookie(cls._site_name)
|
||||
anonymous_cookies = await config.get_cookie(cls._site_name, is_anonymous=True)
|
||||
anonymous_cookie = Cookie(site_name=cls._site_name, content="{}", is_universal=True, is_anonymous=True)
|
||||
for cookie in anonymous_cookies:
|
||||
if not cookie.is_anonymous:
|
||||
@ -99,11 +99,9 @@ class CookieClientManager(ClientManager):
|
||||
|
||||
return _response_hook
|
||||
|
||||
async def _choose_cookie(self, target: Target) -> Cookie:
|
||||
async def _choose_cookie(self, target: Target | None) -> Cookie:
|
||||
"""选择 cookie 的具体算法"""
|
||||
cookies = await config.get_universal_cookie(self._site_name)
|
||||
if target:
|
||||
cookies += await config.get_cookie(self._site_name, target)
|
||||
cookies = await config.get_cookie(self._site_name, target)
|
||||
cookies = (cookie for cookie in cookies if cookie.last_usage + cookie.cd < datetime.now())
|
||||
cookie = min(cookies, key=lambda x: x.last_usage)
|
||||
return cookie
|
||||
|
@ -1,107 +1,74 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from nonebug import App
|
||||
|
||||
|
||||
async def test_get_platform_target(app: App, init_scheduler):
|
||||
async def test_cookie(app: App, init_scheduler):
|
||||
from nonebot_plugin_saa import TargetQQGroup
|
||||
|
||||
from nonebot_bison.platform import site_manager
|
||||
from nonebot_bison.config.db_config import config
|
||||
from nonebot_bison.types import Target as T_Target
|
||||
from nonebot_bison.utils.site import CookieClientManager
|
||||
|
||||
target = T_Target("weibo_id")
|
||||
platform_name = "weibo"
|
||||
await config.add_subscribe(
|
||||
TargetQQGroup(group_id=123),
|
||||
target=T_Target("weibo_id"),
|
||||
target=target,
|
||||
target_name="weibo_name",
|
||||
platform_name="weibo",
|
||||
platform_name=platform_name,
|
||||
cats=[],
|
||||
tags=[],
|
||||
)
|
||||
# await config.add_cookie(TargetQQGroup(group_id=123), "weibo", "cookie")
|
||||
# cookies = await config.get_cookie_by_user(TargetQQGroup(group_id=123))
|
||||
#
|
||||
# res = await config.get_platform_target("weibo")
|
||||
# assert len(res) == 2
|
||||
# await config.del_subscribe(TargetQQGroup(group_id=123), T_Target("weibo_id1"), "weibo")
|
||||
# res = await config.get_platform_target("weibo")
|
||||
# assert len(res) == 2
|
||||
# await config.del_subscribe(TargetQQGroup(group_id=123), T_Target("weibo_id"), "weibo")
|
||||
# res = await config.get_platform_target("weibo")
|
||||
# assert len(res) == 1
|
||||
#
|
||||
# async with AsyncSession(get_engine()) as sess:
|
||||
# res = await sess.scalars(select(Target).where(Target.platform_name == "weibo"))
|
||||
# assert len(res.all()) == 2
|
||||
# await config.get_cookie_by_user(TargetQQGroup(group_id=123))
|
||||
site = site_manager["weibo.com"]
|
||||
client_mgr = cast(CookieClientManager, site.client_mgr)
|
||||
|
||||
await client_mgr.refresh_anonymous_cookie() # 刷新匿名cookie
|
||||
|
||||
async def test_cookie_by_user(app: App, init_scheduler):
|
||||
from nonebot_plugin_saa import TargetQQGroup
|
||||
cookies = await config.get_cookie(site_name=site.name)
|
||||
assert len(cookies) == 1
|
||||
|
||||
from nonebot_bison.config.db_config import config
|
||||
from nonebot_bison.types import Target as T_Target
|
||||
await client_mgr.add_user_cookie(json.dumps({"test_cookie": "1"}))
|
||||
await client_mgr.add_user_cookie(json.dumps({"test_cookie": "2"}))
|
||||
|
||||
cookies = await config.get_cookie(site_name=site.name)
|
||||
assert len(cookies) == 3
|
||||
|
||||
cookies = await config.get_cookie(site_name=site.name, is_anonymous=False)
|
||||
assert len(cookies) == 2
|
||||
|
||||
await config.add_cookie_target(target, platform_name, cookies[0].id)
|
||||
await config.add_cookie_target(target, platform_name, cookies[1].id)
|
||||
|
||||
cookies = await config.get_cookie(site_name=site.name, target=target)
|
||||
assert len(cookies) == 3
|
||||
|
||||
cookies = await config.get_cookie(site_name=site.name, target=target, is_anonymous=False)
|
||||
assert len(cookies) == 2
|
||||
|
||||
cookies = await config.get_cookie(site_name=site.name, target=target, is_universal=False)
|
||||
assert len(cookies) == 2
|
||||
|
||||
# 测试不同的target
|
||||
target2 = T_Target("weibo_id2")
|
||||
await config.add_subscribe(
|
||||
TargetQQGroup(group_id=123),
|
||||
target=T_Target("weibo_id"),
|
||||
target_name="weibo_name",
|
||||
platform_name="weibo",
|
||||
target=target2,
|
||||
target_name="weibo_name2",
|
||||
platform_name=platform_name,
|
||||
cats=[],
|
||||
tags=[],
|
||||
)
|
||||
await client_mgr.add_user_cookie(json.dumps({"test_cookie": "3"}))
|
||||
cookies = await config.get_cookie(site_name=site.name, is_anonymous=False)
|
||||
|
||||
await config.add_cookie_with_content(TargetQQGroup(group_id=123), "weibo", "cookie")
|
||||
await config.add_cookie_target(target2, platform_name, cookies[0].id)
|
||||
await config.add_cookie_target(target2, platform_name, cookies[2].id)
|
||||
|
||||
cookies = await config.get_cookie(TargetQQGroup(group_id=123))
|
||||
cookie = cookies[0]
|
||||
assert len(cookies) == 1
|
||||
assert cookie.content == "cookie"
|
||||
assert cookie.platform_name == "weibo"
|
||||
cookie.last_usage = 0
|
||||
assert cookie.status == ""
|
||||
assert cookie.tags == {}
|
||||
cookie.content = "cookie1"
|
||||
cookie.last_usage = datetime.datetime(2024, 8, 22, 0, 0, 0)
|
||||
cookie.status = "status1"
|
||||
cookie.tags = {"tag1": "value1"}
|
||||
await config.update_cookie(cookie)
|
||||
cookies = await config.get_cookie(TargetQQGroup(group_id=123))
|
||||
cookies = await config.get_cookie(site_name=site.name, target=target2)
|
||||
assert len(cookies) == 3
|
||||
|
||||
assert len(cookies) == 1
|
||||
assert cookies[0].content == cookie.content
|
||||
assert cookies[0].last_usage == cookie.last_usage
|
||||
assert cookies[0].status == cookie.status
|
||||
assert cookies[0].tags == cookie.tags
|
||||
|
||||
await config.delete_cookie_by_id(cookies[0].id)
|
||||
cookies = await config.get_cookie(TargetQQGroup(group_id=123))
|
||||
assert len(cookies) == 0
|
||||
|
||||
|
||||
async def test_cookie_target_by_target(app: App, init_scheduler):
|
||||
from nonebot_plugin_saa import TargetQQGroup
|
||||
|
||||
from nonebot_bison.config.db_config import config
|
||||
from nonebot_bison.types import Target as T_Target
|
||||
|
||||
await config.add_subscribe(
|
||||
TargetQQGroup(group_id=123),
|
||||
target=T_Target("weibo_id"),
|
||||
target_name="weibo_name",
|
||||
platform_name="weibo",
|
||||
cats=[],
|
||||
tags=[],
|
||||
)
|
||||
|
||||
id = await config.add_cookie_with_content(TargetQQGroup(group_id=123), "weibo", "cookie")
|
||||
|
||||
await config.add_cookie_target(T_Target("weibo_id"), "weibo", id)
|
||||
|
||||
cookies = await config.get_cookie_by_target(T_Target("weibo_id"), "weibo")
|
||||
assert len(cookies) == 1
|
||||
assert cookies[0].content == "cookie"
|
||||
assert cookies[0].platform_name == "weibo"
|
||||
|
||||
await config.delete_cookie_target(T_Target("weibo_id"), "weibo", id)
|
||||
cookies = await config.get_cookie_by_target(T_Target("weibo_id"), "weibo")
|
||||
assert len(cookies) == 0
|
||||
cookies = await config.get_cookie(site_name=site.name, target=target2, is_anonymous=False)
|
||||
assert len(cookies) == 2
|
||||
|
Loading…
x
Reference in New Issue
Block a user