felinae98 2022-03-29 22:43:39 +08:00
parent df23648b0f
commit cf35432757
No known key found for this signature in database
GPG Key ID: 00C8B010587FF610
15 changed files with 439 additions and 132 deletions

View File

@ -11,12 +11,12 @@ repos:
- id: isort
- repo: https://github.com/psf/black
rev: 22.1.0
rev: 22.3.0
hooks:
- id: black
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v2.5.1
rev: v2.6.1
hooks:
- id: prettier
types_or: [markdown, ts, tsx]

View File

@ -1,2 +1,3 @@
from .config_legacy import NoSuchSubscribeException, NoSuchUserException, config
from .config_legacy import NoSuchSubscribeException, NoSuchUserException
from .db import DATA
from .db_config import config

View File

@ -1,4 +1,3 @@
import json
from pathlib import Path
import nonebot
@ -7,7 +6,9 @@ from alembic.runtime.environment import EnvironmentContext
from alembic.script.base import ScriptDirectory
from nonebot.log import logger
from nonebot_plugin_datastore import PluginData, create_session, db
from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.engine.base import Connection
from sqlalchemy.ext.asyncio.session import AsyncSession
from .config_legacy import ConfigContent, config
from .db_model import Base, Subscribe, Target, User
@ -21,53 +22,52 @@ async def data_migrate():
all_subs: list[ConfigContent] = list(
map(lambda item: ConfigContent(**item), config.get_all_subscribe().all())
)
print(all_subs)
sess = create_session()
user_to_create = []
subscribe_to_create = []
platform_target_map: dict[str, tuple[Target, str, int]] = {}
for user in all_subs:
db_user = User(uid=user["user"], type=user["user_type"])
user_to_create.append(db_user)
for sub in user["subs"]:
target = sub["target"]
platform_name = sub["target_type"]
target_name = sub["target_name"]
key = f"{target}-{platform_name}"
if key in platform_target_map.keys():
target_obj, ext_user_type, ext_user = platform_target_map[key]
if target_obj.target_name != target_name:
# GG
logger.error(
f"你的旧版本数据库中存在数据不一致问题,请完成迁移后执行重新添加{platform_name}平台的{target}"
f"它的名字可能为{target_obj.target_name}{target_name}"
)
async with AsyncSession(get_engine()) as sess:
user_to_create = []
subscribe_to_create = []
platform_target_map: dict[str, tuple[Target, str, int]] = {}
for user in all_subs:
db_user = User(uid=user["user"], type=user["user_type"])
user_to_create.append(db_user)
for sub in user["subs"]:
target = sub["target"]
platform_name = sub["target_type"]
target_name = sub["target_name"]
key = f"{target}-{platform_name}"
if key in platform_target_map.keys():
target_obj, ext_user_type, ext_user = platform_target_map[key]
if target_obj.target_name != target_name:
# GG
logger.error(
f"你的旧版本数据库中存在数据不一致问题,请完成迁移后执行重新添加{platform_name}平台的{target}"
f"它的名字可能为{target_obj.target_name}{target_name}"
)
else:
target_obj = Target(
platform_name=platform_name,
target_name=target_name,
target=target,
else:
target_obj = Target(
platform_name=platform_name,
target_name=target_name,
target=target,
)
platform_target_map[key] = (
target_obj,
user["user_type"],
user["user"],
)
subscribe_obj = Subscribe(
user=db_user,
target=target_obj,
categories=sub["cats"],
tags=sub["tags"],
)
platform_target_map[key] = (
target_obj,
user["user_type"],
user["user"],
)
subscribe_obj = Subscribe(
user=db_user,
target=target_obj,
categories=json.dumps(sub["cats"]),
tags=json.dumps(sub["tags"]),
)
subscribe_to_create.append(subscribe_obj)
sess.add_all(
user_to_create
+ list(map(lambda x: x[0], platform_target_map.values()))
+ subscribe_to_create
)
await sess.commit()
logger.info("migrate success")
subscribe_to_create.append(subscribe_obj)
sess.add_all(
user_to_create
+ list(map(lambda x: x[0], platform_target_map.values()))
+ subscribe_to_create
)
await sess.commit()
logger.info("migrate success")
@nonebot.get_driver().on_startup
@ -85,7 +85,12 @@ async def upgrade_db():
return script._upgrade_revs("head", revision)
def do_run_migration(connection: Connection):
env.configure(connection, target_metadata=Base.metadata, fn=migrate_fun)
env.configure(
connection,
target_metadata=Base.metadata,
fn=migrate_fun,
render_as_batch=True,
)
with env.begin_transaction():
env.run_migrations()
logger.info("Finish auto migrate")

View File

@ -1,9 +1,11 @@
import json
from typing import Optional
from nonebot_bison.types import Category, Tag, Target
from nonebot_plugin_datastore.db import create_session
from sqlalchemy.sql.expression import select
from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.sql.expression import delete, select
from sqlalchemy.sql.functions import func
from .db_model import Subscribe as MSubscribe
from .db_model import Target as MTarget
@ -11,9 +13,6 @@ from .db_model import User
class DBConfig:
def __init__(self):
self.session = create_session()
async def add_subscribe(
self,
user: int,
@ -24,35 +23,98 @@ class DBConfig:
cats: list[Category],
tags: list[Tag],
):
db_user_stmt = (
select(User).where(User.uid == user).where(User.type == user_type)
)
db_user: Optional[User] = (await self.session.scalars(db_user_stmt)).first()
if not db_user:
db_user = User(uid=user, type=user_type)
self.session.add(db_user)
db_target_stmt = (
select(MTarget)
.where(MTarget.platform_name == platform_name)
.where(MTarget.target == target)
)
db_target: Optional[MTarget] = (
await self.session.scalars(db_target_stmt)
).first()
if not db_target:
db_target = MTarget(
target=target, platform_name=platform_name, target_name=target_name
async with AsyncSession(get_engine()) as session:
db_user_stmt = (
select(User).where(User.uid == user).where(User.type == user_type)
)
else:
db_target.target_name = target_name # type: ignore
subscribe = MSubscribe(
categories=json.dumps(cats),
tags=json.dumps(tags),
user=db_user,
target=db_target,
)
self.session.add(subscribe)
await self.session.commit()
db_user: Optional[User] = await session.scalar(db_user_stmt)
if not db_user:
db_user = User(uid=user, type=user_type)
session.add(db_user)
db_target_stmt = (
select(MTarget)
.where(MTarget.platform_name == platform_name)
.where(MTarget.target == target)
)
db_target: Optional[MTarget] = await session.scalar(db_target_stmt)
if not db_target:
db_target = MTarget(
target=target, platform_name=platform_name, target_name=target_name
)
else:
db_target.target_name = target_name # type: ignore
subscribe = MSubscribe(
categories=cats,
tags=tags,
user=db_user,
target=db_target,
)
session.add(subscribe)
await session.commit()
async def list_subscribe(self, user: int, user_type: str) -> list[MSubscribe]:
async with AsyncSession(get_engine()) as session:
query_stmt = (
select(MSubscribe)
.where(User.type == user_type, User.uid == user)
.join(User)
.options(selectinload(MSubscribe.target))
) # type:ignore
subs: list[MSubscribe] = (await session.scalars(query_stmt)).all()
return subs
async def del_subscribe(
self, user: int, user_type: str, target: str, platform_name: str
):
async with AsyncSession(get_engine()) as session:
user_obj = await session.scalar(
select(User).where(User.uid == user, User.type == user_type)
)
target_obj = await session.scalar(
select(MTarget).where(
MTarget.platform_name == platform_name, MTarget.target == target
)
)
await session.execute(
delete(MSubscribe).where(
MSubscribe.user == user_obj, MSubscribe.target == target_obj
)
)
target_count = await session.scalar(
select(func.count())
.select_from(MSubscribe)
.where(MSubscribe.target == target_obj)
)
if target_count == 0:
# delete empty target
await session.delete(target_obj)
await session.commit()
async def update_subscribe(
self,
user: int,
user_type: str,
target: str,
target_name: str,
platform_name: str,
cats: list,
tags: list,
):
async with AsyncSession(get_engine()) as sess:
subscribe_obj: MSubscribe = await sess.scalar(
select(MSubscribe)
.where(
User.uid == user,
User.type == user_type,
MTarget.target == target,
MTarget.platform_name == platform_name,
)
.join(User)
.join(MTarget)
)
subscribe_obj.tags = tags # type:ignore
subscribe_obj.categories = cats # type:ignore
await sess.commit()
config = DBConfig()
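Every DBConfig method now opens its own AsyncSession around get_engine() and must be awaited by callers. A minimal usage sketch of the module-level config instance, assuming the datastore engine has already been initialized at bot startup; the argument values are illustrative only:

from nonebot_bison.config.db_config import config

async def demo():
    # create a subscription
    await config.add_subscribe(
        user=10000,
        user_type="group",
        target="weibo_id",
        target_name="weibo_name",
        platform_name="weibo",
        cats=[],
        tags=[],
    )
    # list_subscribe eagerly loads each Subscribe's Target via selectinload
    for sub in await config.list_subscribe(10000, "group"):
        print(sub.target.platform_name, sub.target.target, sub.categories, sub.tags)
    # remove it again; an orphaned Target row is cleaned up automatically
    await config.del_subscribe(
        user=10000, user_type="group", target="weibo_id", platform_name="weibo"
    )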

View File

@ -1,35 +1,52 @@
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.sql.schema import Column, ForeignKey
from sqlalchemy.sql.sqltypes import Integer, String
from datetime import datetime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.sql.schema import Column, ForeignKey, UniqueConstraint
from sqlalchemy.sql.sqltypes import JSON, DateTime, Integer, String
Base = declarative_base()
class User(Base):
__tablename__ = "user"
__table_args__ = (UniqueConstraint("type", "uid", name="unique-user-constraint"),)
id = Column(Integer, primary_key=True, autoincrement=True)
type = Column(String(20), nullable=False)
uid = Column(Integer, nullable=False)
subscribes = relationship("Subscribe", back_populates="user")
class Target(Base):
__tablename__ = "target"
__table_args__ = (
UniqueConstraint("target", "platform_name", name="unique-target-constraint"),
)
id = Column(Integer, primary_key=True, autoincrement=True)
platform_name = Column(String(20), nullable=False)
target = Column(String(1024), nullable=False)
target_name = Column(String(1024), nullable=False)
last_schedule_time = Column(
DateTime(timezone=True), default=datetime(year=2000, month=1, day=1)
)
subscribes = relationship("Subscribe", back_populates="target")
class Subscribe(Base):
__tablename__ = "subscribe"
__table_args__ = (
UniqueConstraint("target_id", "user_id", name="unique-subscribe-constraint"),
)
id = Column(Integer, primary_key=True, autoincrement=True)
target_id = Column(Integer, ForeignKey(Target.id))
user_id = Column(Integer, ForeignKey(User.id))
categories = Column(String(1024))
tags = Column(String(1024))
categories = Column(JSON)
tags = Column(JSON)
target = relationship("Target")
user = relationship("User")
target = relationship("Target", back_populates="subscribes")
user = relationship("User", back_populates="subscribes")
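With the JSON columns and back_populates relationships, categories and tags are handled as plain Python lists and both sides of each relationship stay in sync in memory. A minimal sketch with invented values, assuming the models above:

from nonebot_bison.config.db_model import Subscribe, Target, User

user = User(type="group", uid=10000)
target = Target(platform_name="weibo", target="weibo_id", target_name="weibo_name")
sub = Subscribe(user=user, target=target, categories=[1, 2], tags=["tag"])

# back_populates keeps both collections updated before anything is flushed
assert sub in user.subscribes
assert sub in target.subscribes
# JSON columns round-trip the lists as-is; no json.dumps/json.loads needed
assert sub.categories == [1, 2]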

View File

@ -66,7 +66,12 @@ def do_run_migration(connection: Connection):
if __as_plugin:
context.configure(connection=connection)
else:
context.configure(connection=connection, target_metadata=target_metadata)
context.configure(
connection=connection,
target_metadata=target_metadata,
render_as_batch=True,
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
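render_as_batch=True matters because the default nonebot_plugin_datastore backend is typically SQLite, which cannot alter a column's type or add constraints to an existing table in place; batch mode has Alembic rewrite such DDL as a create-copy-rename of the whole table. compare_type=True additionally lets autogenerate notice type changes such as the String-to-JSON switch in this commit. A minimal sketch, assuming a SQLite backend, of a migration written in batch mode; the autogenerated revisions that follow use the same pattern:

import sqlalchemy as sa
from alembic import op

def upgrade():
    # A plain "ALTER TABLE subscribe ALTER COLUMN ..." would be rejected by
    # SQLite; batch mode instead creates a temporary table with the new
    # schema, copies the rows over, and renames it into place.
    with op.batch_alter_table("subscribe") as batch_op:
        batch_op.alter_column(
            "categories", existing_type=sa.VARCHAR(1024), type_=sa.JSON()
        )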

View File

@ -0,0 +1,53 @@
"""alter type
Revision ID: 4a46ba54a3f3
Revises: c97c445e2bdb
Create Date: 2022-03-27 21:50:10.911649
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "4a46ba54a3f3"
down_revision = "c97c445e2bdb"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("subscribe", schema=None) as batch_op:
batch_op.alter_column(
"categories",
existing_type=sa.VARCHAR(length=1024),
type_=sa.JSON(),
existing_nullable=True,
)
batch_op.alter_column(
"tags",
existing_type=sa.VARCHAR(length=1024),
type_=sa.JSON(),
existing_nullable=True,
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("subscribe", schema=None) as batch_op:
batch_op.alter_column(
"tags",
existing_type=sa.JSON(),
type_=sa.VARCHAR(length=1024),
existing_nullable=True,
)
batch_op.alter_column(
"categories",
existing_type=sa.JSON(),
type_=sa.VARCHAR(length=1024),
existing_nullable=True,
)
# ### end Alembic commands ###

View File

@ -0,0 +1,33 @@
"""add last scheduled time
Revision ID: a333d6224193
Revises: 4a46ba54a3f3
Create Date: 2022-03-29 21:01:38.213153
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "a333d6224193"
down_revision = "4a46ba54a3f3"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("target", schema=None) as batch_op:
batch_op.add_column(
sa.Column("last_schedule_time", sa.DateTime(timezone=True), nullable=True)
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("target", schema=None) as batch_op:
batch_op.drop_column("last_schedule_time")
# ### end Alembic commands ###

View File

@ -0,0 +1,47 @@
"""add constraint
Revision ID: c97c445e2bdb
Revises: 0571870f5222
Create Date: 2022-03-26 19:46:50.910721
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "c97c445e2bdb"
down_revision = "0571870f5222"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("subscribe", schema=None) as batch_op:
batch_op.create_unique_constraint(
"unique-subscribe-constraint", ["target_id", "user_id"]
)
with op.batch_alter_table("target", schema=None) as batch_op:
batch_op.create_unique_constraint(
"unique-target-constraint", ["target", "platform_name"]
)
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.drop_constraint("unique-user-constraint", type_="unique")
with op.batch_alter_table("target", schema=None) as batch_op:
batch_op.drop_constraint("unique-target-constraint", type_="unique")
with op.batch_alter_table("subscribe", schema=None) as batch_op:
batch_op.drop_constraint("unique-subscribe-constraint", type_="unique")
# ### end Alembic commands ###

View File

@ -1,6 +1,6 @@
import asyncio
from datetime import datetime
from typing import Optional, Type
from typing import Optional, Type, cast
from nonebot import on_command
from nonebot.adapters.onebot.v11 import Bot, Event, MessageEvent
@ -188,16 +188,16 @@ def do_add_sub(add_sub: Type[Matcher]):
@add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)])
async def add_sub_process(event: Event, state: T_State):
user = state.get("target_user_info")
user = cast(User, state.get("target_user_info"))
assert isinstance(user, User)
config.add_subscribe(
await config.add_subscribe(
# state.get("_user_id") or event.group_id,
# user_type="group",
user=user.user,
user_type=user.user_type,
target=state["id"],
target_name=state["name"],
target_type=state["platform"],
platform_name=state["platform"],
cats=state.get("cats", []),
tags=state.get("tags", []),
)
@ -211,7 +211,7 @@ def do_query_sub(query_sub: Type[Matcher]):
async def _(state: T_State):
user_info = state["target_user_info"]
assert isinstance(user_info, User)
sub_list = config.list_subscribe(
sub_list = await config.list_subscribe(
# state.get("_user_id") or event.group_id, "group"
user_info.user,
user_info.user_type,
@ -219,17 +219,20 @@ def do_query_sub(query_sub: Type[Matcher]):
res = "订阅的帐号为:\n"
for sub in sub_list:
res += "{} {} {}".format(
sub["target_type"], sub["target_name"], sub["target"]
# sub["target_type"], sub["target_name"], sub["target"]
sub.target.platform_name,
sub.target.target_name,
sub.target.target,
)
platform = platform_manager[sub["target_type"]]
platform = platform_manager[sub.target.platform_name]
if platform.categories:
res += " [{}]".format(
", ".join(
map(lambda x: platform.categories[Category(x)], sub["cats"])
map(lambda x: platform.categories[Category(x)], sub.categories)
)
)
if platform.enable_tag:
res += " {}".format(", ".join(sub["tags"]))
res += " {}".format(", ".join(sub.tags))
res += "\n"
await query_sub.finish(Message(await parse_text(res)))
@ -241,7 +244,7 @@ def do_del_sub(del_sub: Type[Matcher]):
async def send_list(bot: Bot, event: Event, state: T_State):
user_info = state["target_user_info"]
assert isinstance(user_info, User)
sub_list = config.list_subscribe(
sub_list = await config.list_subscribe(
# state.get("_user_id") or event.group_id, "group"
user_info.user,
user_info.user_type,
@ -250,21 +253,24 @@ def do_del_sub(del_sub: Type[Matcher]):
state["sub_table"] = {}
for index, sub in enumerate(sub_list, 1):
state["sub_table"][index] = {
"target_type": sub["target_type"],
"target": sub["target"],
"platform_name": sub.target.platform_name,
"target": sub.target.target,
}
res += "{} {} {} {}\n".format(
index, sub["target_type"], sub["target_name"], sub["target"]
index,
sub.target.platform_name,
sub.target.target_name,
sub.target.target,
)
platform = platform_manager[sub["target_type"]]
platform = platform_manager[sub.target.platform_name]
if platform.categories:
res += " [{}]".format(
", ".join(
map(lambda x: platform.categories[Category(x)], sub["cats"])
map(lambda x: platform.categories[Category(x)], sub.categories)
)
)
if platform.enable_tag:
res += " {}".format(", ".join(sub["tags"]))
res += " {}".format(", ".join(sub.tags))
res += "\n"
res += "请输入要删除的订阅的序号"
await bot.send(event=event, message=Message(await parse_text(res)))
@ -275,7 +281,7 @@ def do_del_sub(del_sub: Type[Matcher]):
index = int(str(event.get_message()).strip())
user_info = state["target_user_info"]
assert isinstance(user_info, User)
config.del_subscribe(
await config.del_subscribe(
# state.get("_user_id") or event.group_id,
# "group",
user_info.user,

View File

@ -1,11 +1,11 @@
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Collection, Literal, Optional
import httpx
from nonebot.log import logger
from pydantic.dataclasses import dataclass
from ..plugin_config import plugin_config
from ..post import Post

View File

@ -1,4 +1,3 @@
# from pydantic.dataclasses import dataclass
from dataclasses import dataclass, field
from functools import reduce
from io import BytesIO

View File

@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import Any, Callable, Literal, NamedTuple, NewType
from pydantic.dataclasses import dataclass
RawPost = NewType("RawPost", Any)
Target = NewType("Target", str)
Category = int

View File

@ -1,17 +1,96 @@
from nonebug.app import App
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.functions import func
from sqlmodel.sql.expression import select
async def test_add_subscrib(app: App):
async def test_add_subscribe(app: App, db_migration):
from nonebot_bison.config.db_config import config
from nonebot_bison.types import Target
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as TTarget
from nonebot_plugin_datastore.db import get_engine
await config.add_subscribe(
user=123,
user_type="group",
target=Target("weibo_id"),
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
confs = await config.list_subscribe(123, "group")
assert len(confs) == 1
conf: Subscribe = confs[0]
async with AsyncSession(get_engine()) as sess:
related_user_obj = await sess.scalar(
select(User).where(User.id == conf.user_id)
)
related_target_obj = await sess.scalar(
select(Target).where(Target.id == conf.target_id)
)
assert related_user_obj.uid == 123
assert related_target_obj.target_name == "weibo_name"
assert related_target_obj.target == "weibo_id"
assert conf.target.target == "weibo_id"
assert conf.categories == []
async def test_del_subscribe(db_migration):
from nonebot_bison.config.db_config import config
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as TTarget
from nonebot_plugin_datastore.db import get_engine
await config.add_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
await config.del_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
platform_name="weibo",
)
async with AsyncSession(get_engine()) as sess:
assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 0
assert (await sess.scalar(select(func.count()).select_from(Target))) == 0
await config.add_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
await config.add_subscribe(
user=124,
user_type="group",
target=TTarget("weibo_id"),
target_name="weibo_name_new",
platform_name="weibo",
cats=[],
tags=[],
)
await config.del_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
platform_name="weibo",
)
async with AsyncSession(get_engine()) as sess:
assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 1
assert (await sess.scalar(select(func.count()).select_from(Target))) == 1
target: Target = await sess.scalar(select(Target))
assert target.target_name == "weibo_name_new"

View File

@ -5,6 +5,8 @@ from pathlib import Path
import nonebot
import pytest
from nonebug.app import App
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.expression import delete
@pytest.fixture
@ -12,7 +14,10 @@ async def app(nonebug_init: None, tmp_path: Path, monkeypatch: pytest.MonkeyPatc
import nonebot
config = nonebot.get_driver().config
config.bison_config_path = str(tmp_path)
config.bison_config_path = str(tmp_path / "legacy_config")
config.datastore_config_dir = str(tmp_path / "config")
config.datastore_cache_dir = str(tmp_path / "cache")
config.datastore_data_dir = str(tmp_path / "data")
config.command_start = {""}
config.superusers = {"10001"}
config.log_level = "TRACE"
@ -29,19 +34,15 @@ def dummy_user_subinfo(app: App):
@pytest.fixture
def task_watchdog(request):
def cancel_test_on_exception(task: asyncio.Task):
def maybe_cancel_clbk(t: asyncio.Task):
exception = t.exception()
if exception is None:
return
async def db_migration(app: App):
from nonebot_bison.config.db import upgrade_db
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_plugin_datastore.db import get_engine
for task in asyncio.all_tasks():
coro = task.get_coro()
if coro.__qualname__ == request.function.__qualname__:
task.cancel()
return
task.add_done_callback(maybe_cancel_clbk)
return cancel_test_on_exception
await upgrade_db()
async with AsyncSession(get_engine()) as sess:
await sess.execute(delete(User))
await sess.execute(delete(Subscribe))
await sess.execute(delete(Target))
await sess.commit()
await sess.close()