♻️ 重构 get_cookie 方法

This commit is contained in:
2024-09-13 11:34:45 +08:00
parent 4b8d6a9379
commit 4791fb69e0
6 changed files with 73 additions and 123 deletions
+16 -35
View File
@@ -259,36 +259,32 @@ class DBConfig:
)
return res
async def get_cookie(self, site_name: str = None, target: T_Target = None) -> Sequence[Cookie]:
"""根据 site_name 和 target 获取 cookie,不会返回匿名cookie"""
async def get_cookie(
self,
site_name: str | None = None,
target: T_Target | None = None,
is_universal: bool | None = None,
is_anonymous: bool | None = None,
) -> Sequence[Cookie]:
"""获取满足传入条件的所有 cookie"""
async with create_session() as sess:
query = select(Cookie).distinct().where(Cookie.is_universal == False) # noqa: E712
query = select(Cookie).distinct()
if is_universal is not None:
query = query.where(Cookie.is_universal == is_universal)
if is_anonymous is not None:
query = query.where(Cookie.is_anonymous == is_anonymous)
if site_name:
query = query.where(Cookie.site_name == site_name)
query = query.outerjoin(CookieTarget).options(selectinload(Cookie.targets))
res = (await sess.scalars(query)).all()
if target:
# 如果指定了 target,过滤掉不满足要求的cookie
query = select(CookieTarget.cookie_id).join(Target).where(Target.target == target)
ids = set((await sess.scalars(query)).all())
res = [cookie for cookie in res if cookie.id in ids]
# 如果指定了 target 且未指定 is_universal,则添加返回 universal cookie
res = [cookie for cookie in res if cookie.id in ids or cookie.is_universal]
return res
async def get_unviersal_cookie(self, site_name: str = None) -> Sequence[Cookie]:
async with create_session() as sess:
query = select(Cookie).distinct().where(Cookie.is_universal == True) # noqa: E712
if site_name:
query = query.where(Cookie.site_name == site_name)
res = (await sess.scalars(query)).all()
return res
async def add_cookie_with_content(self, site_name: str, content: str) -> int:
async with create_session() as sess:
cookie = Cookie(site_name=site_name, content=content)
sess.add(cookie)
await sess.commit()
await sess.refresh(cookie)
return cookie.id
async def add_cookie(self, cookie: Cookie) -> int:
async with create_session() as sess:
sess.add(cookie)
@@ -320,21 +316,6 @@ class DBConfig:
await sess.execute(delete(Cookie).where(Cookie.id == cookie_id))
await sess.commit()
async def get_cookie_by_target(self, target: T_Target, site_name: str) -> Sequence[Cookie]:
async with create_session() as sess:
query = (
select(Cookie)
.join(CookieTarget)
.join(Target)
.where(Target.site_name == site_name, Target.target == target)
)
return (await sess.scalars(query)).all()
async def get_universal_cookie(self, site_name: str) -> Sequence[Cookie]:
async with create_session() as sess:
query = select(Cookie).where(Cookie.site_name == site_name).where(Cookie.is_universal == True) # noqa: E712
return (await sess.scalars(query)).all()
async def add_cookie_target(self, target: T_Target, platform_name: str, cookie_id: int):
"""通过 cookie_id 可以唯一确定一个 Cookie,通过 target 和 platform_name 可以唯一确定一个 Target"""
async with create_session() as sess:
@@ -36,10 +36,12 @@ def do_add_cookie_target(add_cookie_target_matcher: type[Matcher]):
@add_cookie_target_matcher.handle()
async def init_promote_cookie(state: T_State):
cookies = await config.get_cookie(site_name=state["site"].name)
# 获取 site 的所有用户 cookie,再排除掉已经关联的 cookie,剩下的就是可以关联的 cookie
cookies = await config.get_cookie(site_name=state["site"].name, is_anonymous=False)
associated_cookies = await config.get_cookie(
target=state["target"]["target"],
site_name=state["site"].name,
is_anonymous=False,
)
associated_cookie_ids = {cookie.id for cookie in associated_cookies}
cookies = [cookie for cookie in cookies if cookie.id not in associated_cookie_ids]
+1 -1
View File
@@ -14,7 +14,7 @@ def do_del_cookie(del_cookie: type[Matcher]):
@del_cookie.handle()
async def send_list(state: T_State):
cookies = await config.get_cookie()
cookies = await config.get_cookie(is_anonymous=False)
if not cookies:
await del_cookie.finish("暂无已添加 Cookie\n请使用“添加cookie”命令添加")
res = "已添加的 Cookie 为:\n"
+3 -1
View File
@@ -108,7 +108,9 @@ async def generate_sub_list_text(
if sub.tags:
res += " {}".format(", ".join(sub.tags)) + "\n"
if is_show_cookie:
target_cookies = await config.get_cookie(target=sub.target.target, site_name=platform.site.name)
target_cookies = await config.get_cookie(
target=sub.target.target, site_name=platform.site.name, is_anonymous=False
)
if target_cookies:
res += " 关联的 Cookie\n"
for cookie in target_cookies:
+3 -5
View File
@@ -49,7 +49,7 @@ class CookieClientManager(ClientManager):
@classmethod
async def refresh_anonymous_cookie(cls):
"""移除已有的匿名cookie,添加一个新的匿名cookie"""
anonymous_cookies = await config.get_unviersal_cookie(cls._site_name)
anonymous_cookies = await config.get_cookie(cls._site_name, is_anonymous=True)
anonymous_cookie = Cookie(site_name=cls._site_name, content="{}", is_universal=True, is_anonymous=True)
for cookie in anonymous_cookies:
if not cookie.is_anonymous:
@@ -99,11 +99,9 @@ class CookieClientManager(ClientManager):
return _response_hook
async def _choose_cookie(self, target: Target) -> Cookie:
async def _choose_cookie(self, target: Target | None) -> Cookie:
"""选择 cookie 的具体算法"""
cookies = await config.get_universal_cookie(self._site_name)
if target:
cookies += await config.get_cookie(self._site_name, target)
cookies = await config.get_cookie(self._site_name, target)
cookies = (cookie for cookie in cookies if cookie.last_usage + cookie.cd < datetime.now())
cookie = min(cookies, key=lambda x: x.last_usage)
return cookie