mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-05 19:36:43 +08:00
update
This commit is contained in:
parent
df23648b0f
commit
cf35432757
@ -11,12 +11,12 @@ repos:
|
||||
- id: isort
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.1.0
|
||||
rev: 22.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: v2.5.1
|
||||
rev: v2.6.1
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or: [markdown, ts, tsx]
|
||||
|
@ -1,2 +1,3 @@
|
||||
from .config_legacy import NoSuchSubscribeException, NoSuchUserException, config
|
||||
from .config_legacy import NoSuchSubscribeException, NoSuchUserException
|
||||
from .db import DATA
|
||||
from .db_config import config
|
||||
|
@ -1,4 +1,3 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import nonebot
|
||||
@ -7,7 +6,9 @@ from alembic.runtime.environment import EnvironmentContext
|
||||
from alembic.script.base import ScriptDirectory
|
||||
from nonebot.log import logger
|
||||
from nonebot_plugin_datastore import PluginData, create_session, db
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
|
||||
from .config_legacy import ConfigContent, config
|
||||
from .db_model import Base, Subscribe, Target, User
|
||||
@ -21,53 +22,52 @@ async def data_migrate():
|
||||
all_subs: list[ConfigContent] = list(
|
||||
map(lambda item: ConfigContent(**item), config.get_all_subscribe().all())
|
||||
)
|
||||
print(all_subs)
|
||||
sess = create_session()
|
||||
user_to_create = []
|
||||
subscribe_to_create = []
|
||||
platform_target_map: dict[str, tuple[Target, str, int]] = {}
|
||||
for user in all_subs:
|
||||
db_user = User(uid=user["user"], type=user["user_type"])
|
||||
user_to_create.append(db_user)
|
||||
for sub in user["subs"]:
|
||||
target = sub["target"]
|
||||
platform_name = sub["target_type"]
|
||||
target_name = sub["target_name"]
|
||||
key = f"{target}-{platform_name}"
|
||||
if key in platform_target_map.keys():
|
||||
target_obj, ext_user_type, ext_user = platform_target_map[key]
|
||||
if target_obj.target_name != target_name:
|
||||
# GG
|
||||
logger.error(
|
||||
f"你的旧版本数据库中存在数据不一致问题,请完成迁移后执行重新添加{platform_name}平台的{target}"
|
||||
f"它的名字可能为{target_obj.target_name}或{target_name}"
|
||||
)
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
user_to_create = []
|
||||
subscribe_to_create = []
|
||||
platform_target_map: dict[str, tuple[Target, str, int]] = {}
|
||||
for user in all_subs:
|
||||
db_user = User(uid=user["user"], type=user["user_type"])
|
||||
user_to_create.append(db_user)
|
||||
for sub in user["subs"]:
|
||||
target = sub["target"]
|
||||
platform_name = sub["target_type"]
|
||||
target_name = sub["target_name"]
|
||||
key = f"{target}-{platform_name}"
|
||||
if key in platform_target_map.keys():
|
||||
target_obj, ext_user_type, ext_user = platform_target_map[key]
|
||||
if target_obj.target_name != target_name:
|
||||
# GG
|
||||
logger.error(
|
||||
f"你的旧版本数据库中存在数据不一致问题,请完成迁移后执行重新添加{platform_name}平台的{target}"
|
||||
f"它的名字可能为{target_obj.target_name}或{target_name}"
|
||||
)
|
||||
|
||||
else:
|
||||
target_obj = Target(
|
||||
platform_name=platform_name,
|
||||
target_name=target_name,
|
||||
target=target,
|
||||
else:
|
||||
target_obj = Target(
|
||||
platform_name=platform_name,
|
||||
target_name=target_name,
|
||||
target=target,
|
||||
)
|
||||
platform_target_map[key] = (
|
||||
target_obj,
|
||||
user["user_type"],
|
||||
user["user"],
|
||||
)
|
||||
subscribe_obj = Subscribe(
|
||||
user=db_user,
|
||||
target=target_obj,
|
||||
categories=sub["cats"],
|
||||
tags=sub["tags"],
|
||||
)
|
||||
platform_target_map[key] = (
|
||||
target_obj,
|
||||
user["user_type"],
|
||||
user["user"],
|
||||
)
|
||||
subscribe_obj = Subscribe(
|
||||
user=db_user,
|
||||
target=target_obj,
|
||||
categories=json.dumps(sub["cats"]),
|
||||
tags=json.dumps(sub["tags"]),
|
||||
)
|
||||
subscribe_to_create.append(subscribe_obj)
|
||||
sess.add_all(
|
||||
user_to_create
|
||||
+ list(map(lambda x: x[0], platform_target_map.values()))
|
||||
+ subscribe_to_create
|
||||
)
|
||||
await sess.commit()
|
||||
logger.info("migrate success")
|
||||
subscribe_to_create.append(subscribe_obj)
|
||||
sess.add_all(
|
||||
user_to_create
|
||||
+ list(map(lambda x: x[0], platform_target_map.values()))
|
||||
+ subscribe_to_create
|
||||
)
|
||||
await sess.commit()
|
||||
logger.info("migrate success")
|
||||
|
||||
|
||||
@nonebot.get_driver().on_startup
|
||||
@ -85,7 +85,12 @@ async def upgrade_db():
|
||||
return script._upgrade_revs("head", revision)
|
||||
|
||||
def do_run_migration(connection: Connection):
|
||||
env.configure(connection, target_metadata=Base.metadata, fn=migrate_fun)
|
||||
env.configure(
|
||||
connection,
|
||||
target_metadata=Base.metadata,
|
||||
fn=migrate_fun,
|
||||
render_as_batch=True,
|
||||
)
|
||||
with env.begin_transaction():
|
||||
env.run_migrations()
|
||||
logger.info("Finish auto migrate")
|
||||
|
@ -1,9 +1,11 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from nonebot_bison.types import Category, Tag, Target
|
||||
from nonebot_plugin_datastore.db import create_session
|
||||
from sqlalchemy.sql.expression import select
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.sql.expression import delete, select
|
||||
from sqlalchemy.sql.functions import func
|
||||
|
||||
from .db_model import Subscribe as MSubscribe
|
||||
from .db_model import Target as MTarget
|
||||
@ -11,9 +13,6 @@ from .db_model import User
|
||||
|
||||
|
||||
class DBConfig:
|
||||
def __init__(self):
|
||||
self.session = create_session()
|
||||
|
||||
async def add_subscribe(
|
||||
self,
|
||||
user: int,
|
||||
@ -24,35 +23,98 @@ class DBConfig:
|
||||
cats: list[Category],
|
||||
tags: list[Tag],
|
||||
):
|
||||
db_user_stmt = (
|
||||
select(User).where(User.uid == user).where(User.type == user_type)
|
||||
)
|
||||
db_user: Optional[User] = (await self.session.scalars(db_user_stmt)).first()
|
||||
if not db_user:
|
||||
db_user = User(uid=user, type=user_type)
|
||||
self.session.add(db_user)
|
||||
db_target_stmt = (
|
||||
select(MTarget)
|
||||
.where(MTarget.platform_name == platform_name)
|
||||
.where(MTarget.target == target)
|
||||
)
|
||||
db_target: Optional[MTarget] = (
|
||||
await self.session.scalars(db_target_stmt)
|
||||
).first()
|
||||
if not db_target:
|
||||
db_target = MTarget(
|
||||
target=target, platform_name=platform_name, target_name=target_name
|
||||
async with AsyncSession(get_engine()) as session:
|
||||
db_user_stmt = (
|
||||
select(User).where(User.uid == user).where(User.type == user_type)
|
||||
)
|
||||
else:
|
||||
db_target.target_name = target_name # type: ignore
|
||||
subscribe = MSubscribe(
|
||||
categories=json.dumps(cats),
|
||||
tags=json.dumps(tags),
|
||||
user=db_user,
|
||||
target=db_target,
|
||||
)
|
||||
self.session.add(subscribe)
|
||||
await self.session.commit()
|
||||
db_user: Optional[User] = await session.scalar(db_user_stmt)
|
||||
if not db_user:
|
||||
db_user = User(uid=user, type=user_type)
|
||||
session.add(db_user)
|
||||
db_target_stmt = (
|
||||
select(MTarget)
|
||||
.where(MTarget.platform_name == platform_name)
|
||||
.where(MTarget.target == target)
|
||||
)
|
||||
db_target: Optional[MTarget] = await session.scalar(db_target_stmt)
|
||||
if not db_target:
|
||||
db_target = MTarget(
|
||||
target=target, platform_name=platform_name, target_name=target_name
|
||||
)
|
||||
else:
|
||||
db_target.target_name = target_name # type: ignore
|
||||
subscribe = MSubscribe(
|
||||
categories=cats,
|
||||
tags=tags,
|
||||
user=db_user,
|
||||
target=db_target,
|
||||
)
|
||||
session.add(subscribe)
|
||||
await session.commit()
|
||||
|
||||
async def list_subscribe(self, user: int, user_type: str) -> list[MSubscribe]:
|
||||
async with AsyncSession(get_engine()) as session:
|
||||
query_stmt = (
|
||||
select(MSubscribe)
|
||||
.where(User.type == user_type and User.uid == user)
|
||||
.join(User)
|
||||
.options(selectinload(MSubscribe.target))
|
||||
) # type:ignore
|
||||
subs: list[MSubscribe] = (await session.scalars(query_stmt)).all()
|
||||
return subs
|
||||
|
||||
async def del_subscribe(
|
||||
self, user: int, user_type: str, target: str, platform_name: str
|
||||
):
|
||||
async with AsyncSession(get_engine()) as session:
|
||||
user_obj = await session.scalar(
|
||||
select(User).where(User.uid == user and User.type == user_type)
|
||||
)
|
||||
target_obj = await session.scalar(
|
||||
select(MTarget).where(
|
||||
MTarget.platform_name == platform_name and MTarget.target == target
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
delete(MSubscribe).where(
|
||||
MSubscribe.user == user_obj and MSubscribe.target == target_obj
|
||||
)
|
||||
)
|
||||
target_count = await session.scalar(
|
||||
select(func.count())
|
||||
.select_from(MSubscribe)
|
||||
.where(MSubscribe.target == target_obj)
|
||||
)
|
||||
if target_count == 0:
|
||||
# delete empty target
|
||||
await session.delete(target_obj)
|
||||
await session.commit()
|
||||
|
||||
async def update_subscribe(
|
||||
self,
|
||||
user: int,
|
||||
user_type: str,
|
||||
target: str,
|
||||
target_name: str,
|
||||
platform_name: str,
|
||||
cats: list,
|
||||
tags: list,
|
||||
):
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
subscribe_obj: MSubscribe = await sess.scalar(
|
||||
select(MSubscribe)
|
||||
.where(
|
||||
User.uid == user
|
||||
and User.type == user_type
|
||||
and MTarget.target == target
|
||||
and MTarget.platform_name == platform_name
|
||||
)
|
||||
.join(User)
|
||||
.join(MTarget)
|
||||
)
|
||||
subscribe_obj.tags = tags # type:ignore
|
||||
subscribe_obj.categories = cats # type:ignore
|
||||
await sess.commit()
|
||||
|
||||
|
||||
config = DBConfig()
|
||||
|
@ -1,35 +1,52 @@
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
from sqlalchemy.sql.schema import Column, ForeignKey
|
||||
from sqlalchemy.sql.sqltypes import Integer, String
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql.schema import Column, ForeignKey, UniqueConstraint
|
||||
from sqlalchemy.sql.sqltypes import JSON, DateTime, Integer, String
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "user"
|
||||
__table_args__ = (UniqueConstraint("type", "uid", name="unique-user-constraint"),)
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
type = Column(String(20), nullable=False)
|
||||
uid = Column(Integer, nullable=False)
|
||||
|
||||
subscribes = relationship("Subscribe", back_populates="user")
|
||||
|
||||
|
||||
class Target(Base):
|
||||
__tablename__ = "target"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("target", "platform_name", name="unique-target-constraint"),
|
||||
)
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
platform_name = Column(String(20), nullable=False)
|
||||
target = Column(String(1024), nullable=False)
|
||||
target_name = Column(String(1024), nullable=False)
|
||||
last_schedule_time = Column(
|
||||
DateTime(timezone=True), default=datetime(year=2000, month=1, day=1)
|
||||
)
|
||||
|
||||
subscribes = relationship("Subscribe", back_populates="target")
|
||||
|
||||
|
||||
class Subscribe(Base):
|
||||
__tablename__ = "subscribe"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("target_id", "user_id", name="unique-subscribe-constraint"),
|
||||
)
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
target_id = Column(Integer, ForeignKey(Target.id))
|
||||
user_id = Column(Integer, ForeignKey(User.id))
|
||||
categories = Column(String(1024))
|
||||
tags = Column(String(1024))
|
||||
categories = Column(JSON)
|
||||
tags = Column(JSON)
|
||||
|
||||
target = relationship("Target")
|
||||
user = relationship("User")
|
||||
target = relationship("Target", back_populates="subscribes")
|
||||
user = relationship("User", back_populates="subscribes")
|
||||
|
@ -66,7 +66,12 @@ def do_run_migration(connection: Connection):
|
||||
if __as_plugin:
|
||||
context.configure(connection=connection)
|
||||
else:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
render_as_batch=True,
|
||||
compare_type=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
@ -0,0 +1,53 @@
|
||||
"""alter type
|
||||
|
||||
Revision ID: 4a46ba54a3f3
|
||||
Revises: c97c445e2bdb
|
||||
Create Date: 2022-03-27 21:50:10.911649
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4a46ba54a3f3"
|
||||
down_revision = "c97c445e2bdb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("subscribe", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"categories",
|
||||
existing_type=sa.VARCHAR(length=1024),
|
||||
type_=sa.JSON(),
|
||||
existing_nullable=True,
|
||||
)
|
||||
batch_op.alter_column(
|
||||
"tags",
|
||||
existing_type=sa.VARCHAR(length=1024),
|
||||
type_=sa.JSON(),
|
||||
existing_nullable=True,
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("subscribe", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"tags",
|
||||
existing_type=sa.JSON(),
|
||||
type_=sa.VARCHAR(length=1024),
|
||||
existing_nullable=True,
|
||||
)
|
||||
batch_op.alter_column(
|
||||
"categories",
|
||||
existing_type=sa.JSON(),
|
||||
type_=sa.VARCHAR(length=1024),
|
||||
existing_nullable=True,
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
@ -0,0 +1,33 @@
|
||||
"""add last scheduled time
|
||||
|
||||
Revision ID: a333d6224193
|
||||
Revises: 4a46ba54a3f3
|
||||
Create Date: 2022-03-29 21:01:38.213153
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a333d6224193"
|
||||
down_revision = "4a46ba54a3f3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("target", schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column("last_schedule_time", sa.DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("target", schema=None) as batch_op:
|
||||
batch_op.drop_column("last_schedule_time")
|
||||
|
||||
# ### end Alembic commands ###
|
@ -0,0 +1,47 @@
|
||||
"""add constraint
|
||||
|
||||
Revision ID: c97c445e2bdb
|
||||
Revises: 0571870f5222
|
||||
Create Date: 2022-03-26 19:46:50.910721
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c97c445e2bdb"
|
||||
down_revision = "0571870f5222"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("subscribe", schema=None) as batch_op:
|
||||
batch_op.create_unique_constraint(
|
||||
"unique-subscribe-constraint", ["target_id", "user_id"]
|
||||
)
|
||||
|
||||
with op.batch_alter_table("target", schema=None) as batch_op:
|
||||
batch_op.create_unique_constraint(
|
||||
"unique-target-constraint", ["target", "platform_name"]
|
||||
)
|
||||
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"])
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.drop_constraint("unique-user-constraint", type_="unique")
|
||||
|
||||
with op.batch_alter_table("target", schema=None) as batch_op:
|
||||
batch_op.drop_constraint("unique-target-constraint", type_="unique")
|
||||
|
||||
with op.batch_alter_table("subscribe", schema=None) as batch_op:
|
||||
batch_op.drop_constraint("unique-subscribe-constraint", type_="unique")
|
||||
|
||||
# ### end Alembic commands ###
|
@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, cast
|
||||
|
||||
from nonebot import on_command
|
||||
from nonebot.adapters.onebot.v11 import Bot, Event, MessageEvent
|
||||
@ -188,16 +188,16 @@ def do_add_sub(add_sub: Type[Matcher]):
|
||||
|
||||
@add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)])
|
||||
async def add_sub_process(event: Event, state: T_State):
|
||||
user = state.get("target_user_info")
|
||||
user = cast(User, state.get("target_user_info"))
|
||||
assert isinstance(user, User)
|
||||
config.add_subscribe(
|
||||
await config.add_subscribe(
|
||||
# state.get("_user_id") or event.group_id,
|
||||
# user_type="group",
|
||||
user=user.user,
|
||||
user_type=user.user_type,
|
||||
target=state["id"],
|
||||
target_name=state["name"],
|
||||
target_type=state["platform"],
|
||||
platform_name=state["platform"],
|
||||
cats=state.get("cats", []),
|
||||
tags=state.get("tags", []),
|
||||
)
|
||||
@ -211,7 +211,7 @@ def do_query_sub(query_sub: Type[Matcher]):
|
||||
async def _(state: T_State):
|
||||
user_info = state["target_user_info"]
|
||||
assert isinstance(user_info, User)
|
||||
sub_list = config.list_subscribe(
|
||||
sub_list = await config.list_subscribe(
|
||||
# state.get("_user_id") or event.group_id, "group"
|
||||
user_info.user,
|
||||
user_info.user_type,
|
||||
@ -219,17 +219,20 @@ def do_query_sub(query_sub: Type[Matcher]):
|
||||
res = "订阅的帐号为:\n"
|
||||
for sub in sub_list:
|
||||
res += "{} {} {}".format(
|
||||
sub["target_type"], sub["target_name"], sub["target"]
|
||||
# sub["target_type"], sub["target_name"], sub["target"]
|
||||
sub.target.platform_name,
|
||||
sub.target.target_name,
|
||||
sub.target.target,
|
||||
)
|
||||
platform = platform_manager[sub["target_type"]]
|
||||
platform = platform_manager[sub.target.platform_name]
|
||||
if platform.categories:
|
||||
res += " [{}]".format(
|
||||
", ".join(
|
||||
map(lambda x: platform.categories[Category(x)], sub["cats"])
|
||||
map(lambda x: platform.categories[Category(x)], sub.categories)
|
||||
)
|
||||
)
|
||||
if platform.enable_tag:
|
||||
res += " {}".format(", ".join(sub["tags"]))
|
||||
res += " {}".format(", ".join(sub.tags))
|
||||
res += "\n"
|
||||
await query_sub.finish(Message(await parse_text(res)))
|
||||
|
||||
@ -241,7 +244,7 @@ def do_del_sub(del_sub: Type[Matcher]):
|
||||
async def send_list(bot: Bot, event: Event, state: T_State):
|
||||
user_info = state["target_user_info"]
|
||||
assert isinstance(user_info, User)
|
||||
sub_list = config.list_subscribe(
|
||||
sub_list = await config.list_subscribe(
|
||||
# state.get("_user_id") or event.group_id, "group"
|
||||
user_info.user,
|
||||
user_info.user_type,
|
||||
@ -250,21 +253,24 @@ def do_del_sub(del_sub: Type[Matcher]):
|
||||
state["sub_table"] = {}
|
||||
for index, sub in enumerate(sub_list, 1):
|
||||
state["sub_table"][index] = {
|
||||
"target_type": sub["target_type"],
|
||||
"target": sub["target"],
|
||||
"platform_name": sub.target.platform_name,
|
||||
"target": sub.target.target,
|
||||
}
|
||||
res += "{} {} {} {}\n".format(
|
||||
index, sub["target_type"], sub["target_name"], sub["target"]
|
||||
index,
|
||||
sub.target.platform_name,
|
||||
sub.target.target_name,
|
||||
sub.target.target,
|
||||
)
|
||||
platform = platform_manager[sub["target_type"]]
|
||||
platform = platform_manager[sub.target.platform_name]
|
||||
if platform.categories:
|
||||
res += " [{}]".format(
|
||||
", ".join(
|
||||
map(lambda x: platform.categories[Category(x)], sub["cats"])
|
||||
map(lambda x: platform.categories[Category(x)], sub.categories)
|
||||
)
|
||||
)
|
||||
if platform.enable_tag:
|
||||
res += " {}".format(", ".join(sub["tags"]))
|
||||
res += " {}".format(", ".join(sub.tags))
|
||||
res += "\n"
|
||||
res += "请输入要删除的订阅的序号"
|
||||
await bot.send(event=event, message=Message(await parse_text(res)))
|
||||
@ -275,7 +281,7 @@ def do_del_sub(del_sub: Type[Matcher]):
|
||||
index = int(str(event.get_message()).strip())
|
||||
user_info = state["target_user_info"]
|
||||
assert isinstance(user_info, User)
|
||||
config.del_subscribe(
|
||||
await config.del_subscribe(
|
||||
# state.get("_user_id") or event.group_id,
|
||||
# "group",
|
||||
user_info.user,
|
||||
|
@ -1,11 +1,11 @@
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Collection, Literal, Optional
|
||||
|
||||
import httpx
|
||||
from nonebot.log import logger
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from ..plugin_config import plugin_config
|
||||
from ..post import Post
|
||||
|
@ -1,4 +1,3 @@
|
||||
# from pydantic.dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from functools import reduce
|
||||
from io import BytesIO
|
||||
|
@ -1,7 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Literal, NamedTuple, NewType
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
RawPost = NewType("RawPost", Any)
|
||||
Target = NewType("Target", str)
|
||||
Category = int
|
||||
|
@ -1,17 +1,96 @@
|
||||
from nonebug.app import App
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.sql.functions import func
|
||||
from sqlmodel.sql.expression import select
|
||||
|
||||
|
||||
async def test_add_subscrib(app: App):
|
||||
async def test_add_subscribe(app: App, db_migration):
|
||||
|
||||
from nonebot_bison.config.db_config import config
|
||||
from nonebot_bison.types import Target
|
||||
from nonebot_bison.config.db_model import Subscribe, Target, User
|
||||
from nonebot_bison.types import Target as TTarget
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
|
||||
await config.add_subscribe(
|
||||
user=123,
|
||||
user_type="group",
|
||||
target=Target("weibo_id"),
|
||||
target=TTarget("weibo_id"),
|
||||
target_name="weibo_name",
|
||||
platform_name="weibo",
|
||||
cats=[],
|
||||
tags=[],
|
||||
)
|
||||
confs = await config.list_subscribe(123, "group")
|
||||
assert len(confs) == 1
|
||||
conf: Subscribe = confs[0]
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
related_user_obj = await sess.scalar(
|
||||
select(User).where(User.id == conf.user_id)
|
||||
)
|
||||
related_target_obj = await sess.scalar(
|
||||
select(Target).where(Target.id == conf.target_id)
|
||||
)
|
||||
assert related_user_obj.uid == 123
|
||||
assert related_target_obj.target_name == "weibo_name"
|
||||
assert related_target_obj.target == "weibo_id"
|
||||
assert conf.target.target == "weibo_id"
|
||||
assert conf.categories == []
|
||||
|
||||
|
||||
async def test_del_subsribe(db_migration):
|
||||
from nonebot_bison.config.db_config import config
|
||||
from nonebot_bison.config.db_model import Subscribe, Target, User
|
||||
from nonebot_bison.types import Target as TTarget
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
|
||||
await config.add_subscribe(
|
||||
user=123,
|
||||
user_type="group",
|
||||
target=TTarget("weibo_id"),
|
||||
target_name="weibo_name",
|
||||
platform_name="weibo",
|
||||
cats=[],
|
||||
tags=[],
|
||||
)
|
||||
await config.del_subscribe(
|
||||
user=123,
|
||||
user_type="group",
|
||||
target=TTarget("weibo_id"),
|
||||
platform_name="weibo",
|
||||
)
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 0
|
||||
assert (await sess.scalar(select(func.count()).select_from(Target))) == 0
|
||||
|
||||
await config.add_subscribe(
|
||||
user=123,
|
||||
user_type="group",
|
||||
target=TTarget("weibo_id"),
|
||||
target_name="weibo_name",
|
||||
platform_name="weibo",
|
||||
cats=[],
|
||||
tags=[],
|
||||
)
|
||||
|
||||
await config.add_subscribe(
|
||||
user=124,
|
||||
user_type="group",
|
||||
target=TTarget("weibo_id"),
|
||||
target_name="weibo_name_new",
|
||||
platform_name="weibo",
|
||||
cats=[],
|
||||
tags=[],
|
||||
)
|
||||
|
||||
await config.del_subscribe(
|
||||
user=123,
|
||||
user_type="group",
|
||||
target=TTarget("weibo_id"),
|
||||
platform_name="weibo",
|
||||
)
|
||||
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 1
|
||||
assert (await sess.scalar(select(func.count()).select_from(Target))) == 1
|
||||
target: Target = await sess.scalar(select(Target))
|
||||
assert target.target_name == "weibo_name_new"
|
||||
|
@ -5,6 +5,8 @@ from pathlib import Path
|
||||
import nonebot
|
||||
import pytest
|
||||
from nonebug.app import App
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.sql.expression import delete
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -12,7 +14,10 @@ async def app(nonebug_init: None, tmp_path: Path, monkeypatch: pytest.MonkeyPatc
|
||||
import nonebot
|
||||
|
||||
config = nonebot.get_driver().config
|
||||
config.bison_config_path = str(tmp_path)
|
||||
config.bison_config_path = str(tmp_path / "legacy_config")
|
||||
config.datastore_config_dir = str(tmp_path / "config")
|
||||
config.datastore_cache_dir = str(tmp_path / "cache")
|
||||
config.datastore_data_dir = str(tmp_path / "data")
|
||||
config.command_start = {""}
|
||||
config.superusers = {"10001"}
|
||||
config.log_level = "TRACE"
|
||||
@ -29,19 +34,15 @@ def dummy_user_subinfo(app: App):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def task_watchdog(request):
|
||||
def cancel_test_on_exception(task: asyncio.Task):
|
||||
def maybe_cancel_clbk(t: asyncio.Task):
|
||||
exception = t.exception()
|
||||
if exception is None:
|
||||
return
|
||||
async def db_migration(app: App):
|
||||
from nonebot_bison.config.db import upgrade_db
|
||||
from nonebot_bison.config.db_model import Subscribe, Target, User
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
|
||||
for task in asyncio.all_tasks():
|
||||
coro = task.get_coro()
|
||||
if coro.__qualname__ == request.function.__qualname__:
|
||||
task.cancel()
|
||||
return
|
||||
|
||||
task.add_done_callback(maybe_cancel_clbk)
|
||||
|
||||
return cancel_test_on_exception
|
||||
await upgrade_db()
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
await sess.execute(delete(User))
|
||||
await sess.execute(delete(Subscribe))
|
||||
await sess.execute(delete(Target))
|
||||
await sess.commit()
|
||||
await sess.close()
|
||||
|
Loading…
x
Reference in New Issue
Block a user