支持对话关联cookie到订阅目标

This commit is contained in:
suyiiyii 2024-08-26 17:28:59 +08:00
parent b655eff755
commit 6f20dbf358
7 changed files with 141 additions and 6 deletions

View File

@ -1,4 +1,5 @@
from .types import Target
from .config.db_model import Cookie
from .scheduler import scheduler_dict
from .platform import platform_manager
@ -15,3 +16,8 @@ async def check_sub_target(platform_name: str, target: Target):
async def check_sub_target_cookie(platform_name: str, target: Target, cookie: str):
# TODO
return "check pass"
async def get_cookie_friendly_name(cookie: Cookie):
# TODO
return cookie.platform_name + cookie.content[:10]

View File

@ -12,7 +12,7 @@ from nonebot_plugin_datastore import create_session
from ..types import Tag
from ..types import Target as T_Target
from .utils import NoSuchTargetException
from .utils import NoSuchTargetException, DuplicateCookieTargetException
from .db_model import User, Cookie, Target, Subscribe, CookieTarget, ScheduleTimeWeight
from ..types import Category, UserSubInfo, WeightConfig, TimeWeightConfig, PlatformWeightConfigResp
@ -259,12 +259,31 @@ class DBConfig:
)
return res
async def get_cookie_by_user(self, user: PlatformTarget) -> list[Cookie]:
async def get_cookie(
self, user: PlatformTarget = None, platform_name: str = None, target: T_Target = None
) -> list[Cookie]:
async with create_session() as sess:
query = select(Cookie).distinct().join(User)
if user:
user_id = await sess.scalar(select(User.id).where(User.user_target == model_dump(user)))
query = query.where(Cookie.user_id == user_id)
if platform_name:
query = query.where(Cookie.platform_name == platform_name)
query = query.outerjoin(CookieTarget).options(selectinload(Cookie.targets))
res = (await sess.scalars(query)).all()
if target:
query = select(CookieTarget.cookie_id).join(Target).where(Target.target == target)
ids = set((await sess.scalars(query)).all())
res = [cookie for cookie in res if cookie.id in ids]
return res
async def get_cookie_by_user_and_platform(self, user: PlatformTarget, platform_name: str) -> list[Cookie]:
async with create_session() as sess:
res = await sess.scalar(
select(User)
.where(User.user_target == model_dump(user))
.join(Cookie)
.where(Cookie.platform_name == platform_name)
.outerjoin(CookieTarget)
.options(selectinload(User.cookies))
)
@ -312,6 +331,12 @@ class DBConfig:
target_obj = await sess.scalar(
select(Target).where(Target.platform_name == platform_name, Target.target == target)
)
# check if relation exists
cookie_target = await sess.scalar(
select(CookieTarget).where(CookieTarget.target == target_obj, CookieTarget.cookie_id == cookie_id)
)
if cookie_target:
raise DuplicateCookieTargetException()
cookie_obj = await sess.scalar(select(Cookie).where(Cookie.id == cookie_id))
cookie_target = CookieTarget(target=target_obj, cookie=cookie_obj)
sess.add(cookie_target)

View File

@ -8,3 +8,7 @@ class NoSuchSubscribeException(Exception):
class NoSuchTargetException(Exception):
pass
class DuplicateCookieTargetException(Exception):
pass

View File

@ -15,6 +15,7 @@ from .add_sub import do_add_sub
from .del_sub import do_del_sub
from .query_sub import do_query_sub
from .add_cookie import do_add_cookie
from .add_cookie_target import do_add_cookie_target
from .utils import common_platform, admin_permission, gen_handle_cancel, configurable_to_me, set_target_user_info
add_sub_matcher = on_command(
@ -51,6 +52,16 @@ add_cookie_matcher = on_command(
add_cookie_matcher.handle()(set_target_user_info)
do_add_cookie(add_cookie_matcher)
add_cookie_target_matcher = on_command(
"关联cookie",
rule=configurable_to_me,
permission=admin_permission(),
priority=5,
block=True,
)
add_cookie_target_matcher.handle()(set_target_user_info)
do_add_cookie_target(add_cookie_target_matcher)
group_manage_matcher = on_command("群管理", rule=to_me(), permission=SUPERUSER, priority=4, block=True)
group_handle_cancel = gen_handle_cancel(group_manage_matcher, "已取消")

View File

@ -58,4 +58,6 @@ def do_add_cookie(add_cookie: type[Matcher]):
@add_cookie.handle()
async def add_cookie_process(state: T_State, user: PlatformTarget = Arg("target_user_info")):
await config.add_cookie(user, state["platform"], state["cookie"])
await add_cookie.finish(f"已添加 Cookie: {state['cookie']} 到平台 {state['platform']}")
await add_cookie.finish(
f"已添加 Cookie: {state['cookie']} 到平台 {state['platform']}" + "\n请使用“关联cookie”为 cookie 关联订阅"
)

View File

@ -0,0 +1,87 @@
from nonebot.typing import T_State
from nonebot.matcher import Matcher
from nonebot.params import Arg, ArgPlainText
from nonebot.internal.adapter import MessageTemplate
from nonebot_plugin_saa import MessageFactory, PlatformTarget
from ..config import config
from ..types import Category
from ..utils import parse_text
from ..platform import platform_manager
from ..apis import get_cookie_friendly_name
from .utils import ensure_user_info, gen_handle_cancel
def do_add_cookie_target(add_cookie_target_matcher: type[Matcher]):
handle_cancel = gen_handle_cancel(add_cookie_target_matcher, "已中止关联 cookie")
add_cookie_target_matcher.handle()(ensure_user_info(add_cookie_target_matcher))
@add_cookie_target_matcher.handle()
async def init_promote(state: T_State, user_info: PlatformTarget = Arg("target_user_info")):
sub_list = await config.list_subscribe(user_info)
if not sub_list:
await add_cookie_target_matcher.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅")
res = "订阅的帐号为:\n"
state["sub_table"] = {}
for index, sub in enumerate(sub_list, 1):
state["sub_table"][index] = {
"platform_name": sub.target.platform_name,
"target": sub.target.target,
}
res += f"{index} {sub.target.platform_name} {sub.target.target_name} {sub.target.target}\n"
if platform := platform_manager.get(sub.target.platform_name):
if platform.categories:
res += " [{}]".format(", ".join(platform.categories[Category(x)] for x in sub.categories))
if platform.enable_tag:
res += " {}".format(", ".join(sub.tags))
else:
res += f" (平台 {sub.target.platform_name} 已失效,请删除此订阅)"
if res[-1] != "\n":
res += "\n"
res += "请输入要关联 cookie 的订阅的序号\n输入'取消'中止"
await MessageFactory(await parse_text(res)).send()
@add_cookie_target_matcher.got("target_idx", parameterless=[handle_cancel])
async def got_target_idx(state: T_State, target_idx: str = ArgPlainText()):
try:
target_idx = int(target_idx)
state["target"] = state["sub_table"][target_idx]
except Exception:
await add_cookie_target_matcher.reject("序号错误")
@add_cookie_target_matcher.handle()
async def init_promote_cookie(state: T_State):
cookies = await config.get_cookie(
user=state["target_user_info"], platform_name=state["target"]["platform_name"]
)
associated_cookies = await config.get_cookie(
user=state["target_user_info"],
target=state["target"]["target"],
platform_name=state["target"]["platform_name"],
)
associated_cookie_ids = {cookie.id for cookie in associated_cookies}
cookies = [cookie for cookie in cookies if cookie.id not in associated_cookie_ids]
if not cookies:
await add_cookie_target_matcher.finish("当前平台暂无 Cookie请使用“添加cookie”命令添加")
state["cookies"] = cookies
state["_prompt"] = "请选择一个 Cookie已关联的 Cookie 不会显示\n" + "\n".join(
[f"{idx}. {await get_cookie_friendly_name(cookie)}" for idx, cookie in enumerate(cookies, 1)]
)
@add_cookie_target_matcher.got("cookie_idx", MessageTemplate("{_prompt}"), [handle_cancel])
async def got_cookie_idx(state: T_State, cookie_idx: str = ArgPlainText()):
try:
cookie_idx = int(cookie_idx)
state["cookie"] = state["cookies"][cookie_idx - 1]
except Exception:
await add_cookie_target_matcher.reject("序号错误")
@add_cookie_target_matcher.handle()
async def add_cookie_target_process(state: T_State, user: PlatformTarget = Arg("target_user_info")):
await config.add_cookie_target(state["target"]["target"], state["target"]["platform_name"], state["cookie"].id)
await add_cookie_target_matcher.finish(
f"已关联 Cookie: {await get_cookie_friendly_name(state['cookie'])} "
f"到订阅 {state['target']['platform_name']} {state['target']['target']}"
)

View File

@ -52,7 +52,7 @@ async def test_cookie_by_user(app: App, init_scheduler):
await config.add_cookie(TargetQQGroup(group_id=123), "weibo", "cookie")
cookies = await config.get_cookie_by_user(TargetQQGroup(group_id=123))
cookies = await config.get_cookie(TargetQQGroup(group_id=123))
cookie = cookies[0]
assert len(cookies) == 1
assert cookie.content == "cookie"
@ -65,7 +65,7 @@ async def test_cookie_by_user(app: App, init_scheduler):
cookie.status = "status1"
cookie.tags = {"tag1": "value1"}
await config.update_cookie(cookie)
cookies = await config.get_cookie_by_user(TargetQQGroup(group_id=123))
cookies = await config.get_cookie(TargetQQGroup(group_id=123))
assert len(cookies) == 1
assert cookies[0].content == cookie.content
@ -74,7 +74,7 @@ async def test_cookie_by_user(app: App, init_scheduler):
assert cookies[0].tags == cookie.tags
await config.delete_cookie(cookies[0].id)
cookies = await config.get_cookie_by_user(TargetQQGroup(group_id=123))
cookies = await config.get_cookie(TargetQQGroup(group_id=123))
assert len(cookies) == 0