💥 适配最新的 DataStore 插件,并切换模型为 SQLModel (#178)

* 使用 SQLModel

* 处理数据库迁移

* 与之前的模型相匹配

* sqlmodel 和 sqlalchemy 的导入移入测试函数内

并且使用 init_db 且测试前加载插件

* 重命名 alembic_version 表之前先检查是否存在且 version_num 属于插件

* 降级应该是把名称重新命名回去而不是删掉

* 不再设置 arbitrary_types_allowed 为 True
This commit is contained in:
uy/sun 2023-01-30 22:52:11 +08:00 committed by GitHub
parent 312847fb6a
commit 8da8f66fcf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1993 additions and 1871 deletions

3604
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -39,7 +39,7 @@ aiofiles = "^0.8.0"
python-socketio = "^5.4.0"
nonebot-adapter-onebot = "^2.0.0-beta.1"
nonebot-plugin-htmlrender = ">=0.2.0"
nonebot-plugin-datastore = "^0.4.0"
nonebot-plugin-datastore = "^0.5.2"
alembic = "^1.7.6"
[tool.poetry.dev-dependencies]
@ -78,3 +78,6 @@ extend-exclude = '''
profile = "black"
line_length = 88
skip_gitignore = true
[tool.nonebot]
plugins = ["src.plugins.nonebot_bison"]

View File

@ -1,18 +1,52 @@
from nonebot import get_driver
from nonebot.log import logger
from nonebot_plugin_datastore.db import get_engine, post_db_init, pre_db_init
from sqlalchemy import inspect, text
from .config.config_legacy import start_up as legacy_db_startup
from .config.db import upgrade_db
from .config.db import data_migrate
from .scheduler.aps import start_scheduler
from .scheduler.manager import init_scheduler
@get_driver().on_startup
async def bootstrap():
@pre_db_init
async def pre():
def _has_table(conn, table_name):
insp = inspect(conn)
return insp.has_table(table_name)
async with get_engine().begin() as conn:
if not await conn.run_sync(_has_table, "alembic_version"):
logger.debug("未发现默认版本数据库,开始初始化")
return
logger.debug("发现默认版本数据库,开始检查版本")
t = await conn.scalar(text("select version_num from alembic_version"))
if t not in [
"4a46ba54a3f3", # alter_type
"5f3370328e44", # add_time_weight_table
"0571870f5222", # init_db
"a333d6224193", # add_last_scheduled_time
"c97c445e2bdb", # add_constraint
]:
logger.warning(f"当前数据库版本:{t},不是插件的版本,已跳过。")
return
logger.debug(f"当前数据库版本:{t},是插件的版本,开始迁移。")
# 删除可能存在的版本数据库
if await conn.run_sync(_has_table, "nonebot_bison_alembic_version"):
await conn.execute(text("drop table nonebot_bison_alembic_version"))
await conn.execute(
text("alter table alembic_version rename to nonebot_bison_alembic_version")
)
@post_db_init
async def post():
# legacy db
legacy_db_startup()
# new db
await upgrade_db()
# migrate data
await data_migrate()
# init scheduler
await init_scheduler()
# start scheduler

View File

@ -1,3 +1,2 @@
from .db import DATA
from .db_config import config
from .utils import NoSuchSubscribeException, NoSuchTargetException, NoSuchUserException

View File

@ -1,18 +1,9 @@
from pathlib import Path
from alembic.config import Config
from alembic.runtime.environment import EnvironmentContext
from alembic.script.base import ScriptDirectory
from nonebot.log import logger
from nonebot_plugin_datastore import PluginData, db
from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.engine.base import Connection
from sqlalchemy.ext.asyncio.session import AsyncSession
from .config_legacy import ConfigContent, config, drop
from .db_model import Base, Subscribe, Target, User
DATA = PluginData("bison")
from .db_model import Subscribe, Target, User
async def data_migrate():
@ -77,33 +68,3 @@ async def data_migrate():
await sess.commit()
drop()
logger.info("migrate success")
async def upgrade_db():
alembic_cfg = Config()
alembic_cfg.set_main_option(
"script_location", str(Path(__file__).parent.joinpath("migrate"))
)
script = ScriptDirectory.from_config(alembic_cfg)
engine = db.get_engine()
env = EnvironmentContext(alembic_cfg, script)
def migrate_fun(revision, context):
return script._upgrade_revs("head", revision)
def do_run_migration(connection: Connection):
env.configure(
connection,
target_metadata=Base.metadata,
fn=migrate_fun,
render_as_batch=True,
)
with env.begin_transaction():
env.run_migrations()
logger.info("Finish auto migrate")
async with engine.connect() as connection:
await connection.run_sync(do_run_migration)
await data_migrate()

View File

@ -2,12 +2,10 @@ from collections import defaultdict
from datetime import datetime, time
from typing import Awaitable, Callable, Optional
from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_datastore import create_session
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.sql.expression import delete, select
from sqlalchemy.sql.functions import func
from sqlmodel import delete, func, select
from ..types import Category, PlatformWeightConfigResp, Tag
from ..types import Target as T_Target
@ -49,7 +47,7 @@ class DBConfig:
cats: list[Category],
tags: list[Tag],
):
async with AsyncSession(get_engine()) as session:
async with create_session() as session:
db_user_stmt = (
select(User).where(User.uid == user).where(User.type == user_type)
)
@ -86,7 +84,7 @@ class DBConfig:
raise e
async def list_subscribe(self, user: int, user_type: str) -> list[Subscribe]:
async with AsyncSession(get_engine()) as session:
async with create_session() as session:
query_stmt = (
select(Subscribe)
.where(User.type == user_type, User.uid == user)
@ -99,7 +97,7 @@ class DBConfig:
async def del_subscribe(
self, user: int, user_type: str, target: str, platform_name: str
):
async with AsyncSession(get_engine()) as session:
async with create_session() as session:
user_obj = await session.scalar(
select(User).where(User.uid == user, User.type == user_type)
)
@ -135,7 +133,7 @@ class DBConfig:
cats: list,
tags: list,
):
async with AsyncSession(get_engine()) as sess:
async with create_session() as sess:
subscribe_obj: Subscribe = await sess.scalar(
select(Subscribe)
.where(
@ -154,7 +152,7 @@ class DBConfig:
await sess.commit()
async def get_platform_target(self, platform_name: str) -> list[Target]:
async with AsyncSession(get_engine()) as sess:
async with create_session() as sess:
subq = select(Subscribe.target_id).distinct().subquery()
query = (
select(Target).join(subq).where(Target.platform_name == platform_name)
@ -164,7 +162,7 @@ class DBConfig:
async def get_time_weight_config(
self, target: T_Target, platform_name: str
) -> WeightConfig:
async with AsyncSession(get_engine()) as sess:
async with create_session() as sess:
time_weight_conf: list[ScheduleTimeWeight] = (
await sess.scalars(
select(ScheduleTimeWeight)
@ -194,7 +192,7 @@ class DBConfig:
async def update_time_weight_config(
self, target: T_Target, platform_name: str, conf: WeightConfig
):
async with AsyncSession(get_engine()) as sess:
async with create_session() as sess:
targetObj: Target = await sess.scalar(
select(Target).where(
Target.platform_name == platform_name, Target.target == target
@ -222,7 +220,7 @@ class DBConfig:
async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, int]:
res = {}
cur_time = _get_time()
async with AsyncSession(get_engine()) as sess:
async with create_session() as sess:
targets: list[Target] = (
await sess.scalars(
select(Target)
@ -246,7 +244,7 @@ class DBConfig:
async def get_platform_target_subscribers(
self, platform_name: str, target: T_Target
) -> list[UserSubInfo]:
async with AsyncSession(get_engine()) as sess:
async with create_session() as sess:
query = (
select(Subscribe)
.join(Target)
@ -269,7 +267,7 @@ class DBConfig:
self,
) -> dict[str, dict[str, PlatformWeightConfigResp]]:
res: dict[str, dict[str, PlatformWeightConfigResp]] = defaultdict(dict)
async with AsyncSession(get_engine()) as sess:
async with create_session() as sess:
query = select(Target)
targets: list[Target] = (await sess.scalars(query)).all()
query = select(ScheduleTimeWeight).options(

View File

@ -1,61 +1,68 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.sql.schema import Column, ForeignKey, UniqueConstraint
from sqlalchemy.sql.sqltypes import JSON, Integer, String, Time
import datetime
from pathlib import Path
from typing import Optional
Base = declarative_base()
from nonebot_plugin_datastore import get_plugin_data
from sqlmodel import JSON, Column, Field, Relationship, UniqueConstraint
from ..types import Category, Tag
Model = get_plugin_data().Model
get_plugin_data().set_migration_dir(Path(__file__).parent / "migrate" / "versions")
class User(Base):
__tablename__ = "user"
class User(Model, table=True):
__table_args__ = (UniqueConstraint("type", "uid", name="unique-user-constraint"),)
id = Column(Integer, primary_key=True, autoincrement=True)
type = Column(String(20), nullable=False)
uid = Column(Integer, nullable=False)
id: Optional[int] = Field(default=None, primary_key=True)
type: str = Field(max_length=20)
uid: int
subscribes = relationship("Subscribe", back_populates="user")
subscribes: list["Subscribe"] = Relationship(back_populates="user")
class Target(Base):
__tablename__ = "target"
class Target(Model, table=True):
__table_args__ = (
UniqueConstraint("target", "platform_name", name="unique-target-constraint"),
)
id = Column(Integer, primary_key=True, autoincrement=True)
platform_name = Column(String(20), nullable=False)
target = Column(String(1024), nullable=False)
target_name = Column(String(1024), nullable=False)
default_schedule_weight = Column(Integer, default=10)
id: Optional[int] = Field(default=None, primary_key=True)
platform_name: str = Field(max_length=20)
target: str = Field(max_length=1024)
target_name: str = Field(max_length=1024)
default_schedule_weight: Optional[int] = Field(default=10)
subscribes = relationship("Subscribe", back_populates="target")
time_weight = relationship("ScheduleTimeWeight", back_populates="target")
subscribes: list["Subscribe"] = Relationship(back_populates="target")
time_weight: list["ScheduleTimeWeight"] = Relationship(back_populates="target")
class ScheduleTimeWeight(Base):
__tablename__ = "schedule_time_weight"
class ScheduleTimeWeight(Model, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
target_id: Optional[int] = Field(
default=None, foreign_key="nonebot_bison_target.id"
)
start_time: Optional[datetime.time]
end_time: Optional[datetime.time]
weight: Optional[int]
id = Column(Integer, primary_key=True, autoincrement=True)
target_id = Column(Integer, ForeignKey(Target.id))
start_time = Column(Time)
end_time = Column(Time)
weight = Column(Integer)
target: Target = Relationship(back_populates="time_weight")
target = relationship("Target", back_populates="time_weight")
class Config:
arbitrary_types_allowed = True
class Subscribe(Base):
__tablename__ = "subscribe"
class Subscribe(Model, table=True):
__table_args__ = (
UniqueConstraint("target_id", "user_id", name="unique-subscribe-constraint"),
)
id = Column(Integer, primary_key=True, autoincrement=True)
target_id = Column(Integer, ForeignKey(Target.id))
user_id = Column(Integer, ForeignKey(User.id))
categories = Column(JSON)
tags = Column(JSON)
id: Optional[int] = Field(default=None, primary_key=True)
target_id: Optional[int] = Field(
default=None, foreign_key="nonebot_bison_target.id"
)
user_id: Optional[int] = Field(default=None, foreign_key="nonebot_bison_user.id")
categories: list[Category] = Field(sa_column=Column(JSON))
tags: list[Tag] = Field(sa_column=Column(JSON))
target = relationship("Target", back_populates="subscribes")
user = relationship("User", back_populates="subscribes")
target: Target = Relationship(back_populates="subscribes")
user: User = Relationship(back_populates="subscribes")

View File

@ -0,0 +1,34 @@
"""rename tables
Revision ID: 5da28f6facb3
Revises: 5f3370328e44
Create Date: 2023-01-15 19:04:54.987491
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
# revision identifiers, used by Alembic.
revision = "5da28f6facb3"
down_revision = "5f3370328e44"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.rename_table("target", "nonebot_bison_target")
op.rename_table("user", "nonebot_bison_user")
op.rename_table("schedule_time_weight", "nonebot_bison_scheduletimeweight")
op.rename_table("subscribe", "nonebot_bison_subscribe")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.rename_table("nonebot_bison_subscribe", "subscribe")
op.rename_table("nonebot_bison_scheduletimeweight", "schedule_time_weight")
op.rename_table("nonebot_bison_user", "user")
op.rename_table("nonebot_bison_target", "target")
# ### end Alembic commands ###

View File

@ -1,16 +1,14 @@
import pytest
from nonebug.app import App
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.functions import func
from sqlmodel.sql.expression import select
async def test_add_subscribe(app: App, init_scheduler):
from nonebot_bison.config.db_config import config
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as TTarget
from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import select
await config.add_subscribe(
user=123,
@ -74,7 +72,6 @@ async def test_add_subscribe(app: App, init_scheduler):
async def test_add_dup_sub(init_scheduler):
from nonebot_bison.config.db_config import SubscribeDupException, config
from nonebot_bison.types import Target as TTarget
@ -102,9 +99,12 @@ async def test_add_dup_sub(init_scheduler):
async def test_del_subsribe(init_scheduler):
from nonebot_bison.config.db_config import config
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.config.db_model import Subscribe, Target
from nonebot_bison.types import Target as TTarget
from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.functions import func
from sqlmodel.sql.expression import select
await config.add_subscribe(
user=123,

View File

@ -1,7 +1,7 @@
async def test_migration(use_legacy_config):
from nonebot_bison.config.config_legacy import config as config_legacy
from nonebot_bison.config.db import upgrade_db
from nonebot_bison.config.db_config import config
from nonebot_plugin_datastore.db import init_db
config_legacy.add_subscribe(
user=123,
@ -31,7 +31,7 @@ async def test_migration(use_legacy_config):
tags=[],
)
# await data_migrate()
await upgrade_db()
await init_db()
user123_config = await config.list_subscribe(123, "group")
assert len(user123_config) == 2
for c in user123_config:
@ -55,8 +55,8 @@ async def test_migration(use_legacy_config):
async def test_migrate_dup(use_legacy_config):
from nonebot_bison.config.config_legacy import config as config_legacy
from nonebot_bison.config.db import upgrade_db
from nonebot_bison.config.db_config import config
from nonebot_plugin_datastore.db import init_db
config_legacy.add_subscribe(
user=123,
@ -77,6 +77,6 @@ async def test_migrate_dup(use_legacy_config):
tags=[],
)
# await data_migrate()
await upgrade_db()
await init_db()
user123_config = await config.list_subscribe(123, "group")
assert len(user123_config) == 1

View File

@ -22,6 +22,7 @@ async def app(nonebug_init: None, tmp_path: Path, monkeypatch: pytest.MonkeyPatc
config.superusers = {"10001"}
config.log_level = "TRACE"
config.bison_filter_log = False
nonebot.require("nonebot_bison")
return App(monkeypatch)
@ -35,11 +36,10 @@ def dummy_user_subinfo(app: App):
@pytest.fixture
async def db_migration(app: App):
from nonebot_bison.config.db import upgrade_db
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_datastore.db import get_engine, init_db
await upgrade_db()
await init_db()
async with AsyncSession(get_engine()) as sess:
await sess.execute(delete(User))
await sess.execute(delete(Subscribe))