mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-04 02:26:11 +08:00
💥 适配最新的 DataStore 插件,并切换模型为 SQLModel (#178)
* 使用 SQLModel * 处理数据库迁移 * 与之前的模型相匹配 * sqlmodel 和 sqlalchemy 的导入移入测试函数内 并且使用 init_db 且测试前加载插件 * 重命名 alembic_version 表之前先检查是否存在且 version_num 属于插件 * 降级应该是把名称重新命名回去而不是删掉 * 不再设置 arbitrary_types_allowed 为 True
This commit is contained in:
parent
312847fb6a
commit
8da8f66fcf
3604
poetry.lock
generated
3604
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -39,7 +39,7 @@ aiofiles = "^0.8.0"
|
||||
python-socketio = "^5.4.0"
|
||||
nonebot-adapter-onebot = "^2.0.0-beta.1"
|
||||
nonebot-plugin-htmlrender = ">=0.2.0"
|
||||
nonebot-plugin-datastore = "^0.4.0"
|
||||
nonebot-plugin-datastore = "^0.5.2"
|
||||
alembic = "^1.7.6"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
@ -78,3 +78,6 @@ extend-exclude = '''
|
||||
profile = "black"
|
||||
line_length = 88
|
||||
skip_gitignore = true
|
||||
|
||||
[tool.nonebot]
|
||||
plugins = ["src.plugins.nonebot_bison"]
|
@ -1,18 +1,52 @@
|
||||
from nonebot import get_driver
|
||||
from nonebot.log import logger
|
||||
from nonebot_plugin_datastore.db import get_engine, post_db_init, pre_db_init
|
||||
from sqlalchemy import inspect, text
|
||||
|
||||
from .config.config_legacy import start_up as legacy_db_startup
|
||||
from .config.db import upgrade_db
|
||||
from .config.db import data_migrate
|
||||
from .scheduler.aps import start_scheduler
|
||||
from .scheduler.manager import init_scheduler
|
||||
|
||||
|
||||
@get_driver().on_startup
|
||||
async def bootstrap():
|
||||
@pre_db_init
|
||||
async def pre():
|
||||
def _has_table(conn, table_name):
|
||||
insp = inspect(conn)
|
||||
return insp.has_table(table_name)
|
||||
|
||||
async with get_engine().begin() as conn:
|
||||
if not await conn.run_sync(_has_table, "alembic_version"):
|
||||
logger.debug("未发现默认版本数据库,开始初始化")
|
||||
return
|
||||
|
||||
logger.debug("发现默认版本数据库,开始检查版本")
|
||||
t = await conn.scalar(text("select version_num from alembic_version"))
|
||||
if t not in [
|
||||
"4a46ba54a3f3", # alter_type
|
||||
"5f3370328e44", # add_time_weight_table
|
||||
"0571870f5222", # init_db
|
||||
"a333d6224193", # add_last_scheduled_time
|
||||
"c97c445e2bdb", # add_constraint
|
||||
]:
|
||||
logger.warning(f"当前数据库版本:{t},不是插件的版本,已跳过。")
|
||||
return
|
||||
|
||||
logger.debug(f"当前数据库版本:{t},是插件的版本,开始迁移。")
|
||||
# 删除可能存在的版本数据库
|
||||
if await conn.run_sync(_has_table, "nonebot_bison_alembic_version"):
|
||||
await conn.execute(text("drop table nonebot_bison_alembic_version"))
|
||||
|
||||
await conn.execute(
|
||||
text("alter table alembic_version rename to nonebot_bison_alembic_version")
|
||||
)
|
||||
|
||||
|
||||
@post_db_init
|
||||
async def post():
|
||||
# legacy db
|
||||
legacy_db_startup()
|
||||
# new db
|
||||
await upgrade_db()
|
||||
# migrate data
|
||||
await data_migrate()
|
||||
# init scheduler
|
||||
await init_scheduler()
|
||||
# start scheduler
|
||||
|
@ -1,3 +1,2 @@
|
||||
from .db import DATA
|
||||
from .db_config import config
|
||||
from .utils import NoSuchSubscribeException, NoSuchTargetException, NoSuchUserException
|
||||
|
@ -1,18 +1,9 @@
|
||||
from pathlib import Path
|
||||
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.environment import EnvironmentContext
|
||||
from alembic.script.base import ScriptDirectory
|
||||
from nonebot.log import logger
|
||||
from nonebot_plugin_datastore import PluginData, db
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
|
||||
from .config_legacy import ConfigContent, config, drop
|
||||
from .db_model import Base, Subscribe, Target, User
|
||||
|
||||
DATA = PluginData("bison")
|
||||
from .db_model import Subscribe, Target, User
|
||||
|
||||
|
||||
async def data_migrate():
|
||||
@ -77,33 +68,3 @@ async def data_migrate():
|
||||
await sess.commit()
|
||||
drop()
|
||||
logger.info("migrate success")
|
||||
|
||||
|
||||
async def upgrade_db():
|
||||
alembic_cfg = Config()
|
||||
alembic_cfg.set_main_option(
|
||||
"script_location", str(Path(__file__).parent.joinpath("migrate"))
|
||||
)
|
||||
|
||||
script = ScriptDirectory.from_config(alembic_cfg)
|
||||
engine = db.get_engine()
|
||||
env = EnvironmentContext(alembic_cfg, script)
|
||||
|
||||
def migrate_fun(revision, context):
|
||||
return script._upgrade_revs("head", revision)
|
||||
|
||||
def do_run_migration(connection: Connection):
|
||||
env.configure(
|
||||
connection,
|
||||
target_metadata=Base.metadata,
|
||||
fn=migrate_fun,
|
||||
render_as_batch=True,
|
||||
)
|
||||
with env.begin_transaction():
|
||||
env.run_migrations()
|
||||
logger.info("Finish auto migrate")
|
||||
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(do_run_migration)
|
||||
|
||||
await data_migrate()
|
||||
|
@ -2,12 +2,10 @@ from collections import defaultdict
|
||||
from datetime import datetime, time
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
from nonebot_plugin_datastore import create_session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.sql.expression import delete, select
|
||||
from sqlalchemy.sql.functions import func
|
||||
from sqlmodel import delete, func, select
|
||||
|
||||
from ..types import Category, PlatformWeightConfigResp, Tag
|
||||
from ..types import Target as T_Target
|
||||
@ -49,7 +47,7 @@ class DBConfig:
|
||||
cats: list[Category],
|
||||
tags: list[Tag],
|
||||
):
|
||||
async with AsyncSession(get_engine()) as session:
|
||||
async with create_session() as session:
|
||||
db_user_stmt = (
|
||||
select(User).where(User.uid == user).where(User.type == user_type)
|
||||
)
|
||||
@ -86,7 +84,7 @@ class DBConfig:
|
||||
raise e
|
||||
|
||||
async def list_subscribe(self, user: int, user_type: str) -> list[Subscribe]:
|
||||
async with AsyncSession(get_engine()) as session:
|
||||
async with create_session() as session:
|
||||
query_stmt = (
|
||||
select(Subscribe)
|
||||
.where(User.type == user_type, User.uid == user)
|
||||
@ -99,7 +97,7 @@ class DBConfig:
|
||||
async def del_subscribe(
|
||||
self, user: int, user_type: str, target: str, platform_name: str
|
||||
):
|
||||
async with AsyncSession(get_engine()) as session:
|
||||
async with create_session() as session:
|
||||
user_obj = await session.scalar(
|
||||
select(User).where(User.uid == user, User.type == user_type)
|
||||
)
|
||||
@ -135,7 +133,7 @@ class DBConfig:
|
||||
cats: list,
|
||||
tags: list,
|
||||
):
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
async with create_session() as sess:
|
||||
subscribe_obj: Subscribe = await sess.scalar(
|
||||
select(Subscribe)
|
||||
.where(
|
||||
@ -154,7 +152,7 @@ class DBConfig:
|
||||
await sess.commit()
|
||||
|
||||
async def get_platform_target(self, platform_name: str) -> list[Target]:
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
async with create_session() as sess:
|
||||
subq = select(Subscribe.target_id).distinct().subquery()
|
||||
query = (
|
||||
select(Target).join(subq).where(Target.platform_name == platform_name)
|
||||
@ -164,7 +162,7 @@ class DBConfig:
|
||||
async def get_time_weight_config(
|
||||
self, target: T_Target, platform_name: str
|
||||
) -> WeightConfig:
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
async with create_session() as sess:
|
||||
time_weight_conf: list[ScheduleTimeWeight] = (
|
||||
await sess.scalars(
|
||||
select(ScheduleTimeWeight)
|
||||
@ -194,7 +192,7 @@ class DBConfig:
|
||||
async def update_time_weight_config(
|
||||
self, target: T_Target, platform_name: str, conf: WeightConfig
|
||||
):
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
async with create_session() as sess:
|
||||
targetObj: Target = await sess.scalar(
|
||||
select(Target).where(
|
||||
Target.platform_name == platform_name, Target.target == target
|
||||
@ -222,7 +220,7 @@ class DBConfig:
|
||||
async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, int]:
|
||||
res = {}
|
||||
cur_time = _get_time()
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
async with create_session() as sess:
|
||||
targets: list[Target] = (
|
||||
await sess.scalars(
|
||||
select(Target)
|
||||
@ -246,7 +244,7 @@ class DBConfig:
|
||||
async def get_platform_target_subscribers(
|
||||
self, platform_name: str, target: T_Target
|
||||
) -> list[UserSubInfo]:
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
async with create_session() as sess:
|
||||
query = (
|
||||
select(Subscribe)
|
||||
.join(Target)
|
||||
@ -269,7 +267,7 @@ class DBConfig:
|
||||
self,
|
||||
) -> dict[str, dict[str, PlatformWeightConfigResp]]:
|
||||
res: dict[str, dict[str, PlatformWeightConfigResp]] = defaultdict(dict)
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
async with create_session() as sess:
|
||||
query = select(Target)
|
||||
targets: list[Target] = (await sess.scalars(query)).all()
|
||||
query = select(ScheduleTimeWeight).options(
|
||||
|
@ -1,61 +1,68 @@
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql.schema import Column, ForeignKey, UniqueConstraint
|
||||
from sqlalchemy.sql.sqltypes import JSON, Integer, String, Time
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
Base = declarative_base()
|
||||
from nonebot_plugin_datastore import get_plugin_data
|
||||
from sqlmodel import JSON, Column, Field, Relationship, UniqueConstraint
|
||||
|
||||
from ..types import Category, Tag
|
||||
|
||||
Model = get_plugin_data().Model
|
||||
get_plugin_data().set_migration_dir(Path(__file__).parent / "migrate" / "versions")
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "user"
|
||||
class User(Model, table=True):
|
||||
__table_args__ = (UniqueConstraint("type", "uid", name="unique-user-constraint"),)
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
type = Column(String(20), nullable=False)
|
||||
uid = Column(Integer, nullable=False)
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
type: str = Field(max_length=20)
|
||||
uid: int
|
||||
|
||||
subscribes = relationship("Subscribe", back_populates="user")
|
||||
subscribes: list["Subscribe"] = Relationship(back_populates="user")
|
||||
|
||||
|
||||
class Target(Base):
|
||||
__tablename__ = "target"
|
||||
class Target(Model, table=True):
|
||||
__table_args__ = (
|
||||
UniqueConstraint("target", "platform_name", name="unique-target-constraint"),
|
||||
)
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
platform_name = Column(String(20), nullable=False)
|
||||
target = Column(String(1024), nullable=False)
|
||||
target_name = Column(String(1024), nullable=False)
|
||||
default_schedule_weight = Column(Integer, default=10)
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
platform_name: str = Field(max_length=20)
|
||||
target: str = Field(max_length=1024)
|
||||
target_name: str = Field(max_length=1024)
|
||||
default_schedule_weight: Optional[int] = Field(default=10)
|
||||
|
||||
subscribes = relationship("Subscribe", back_populates="target")
|
||||
time_weight = relationship("ScheduleTimeWeight", back_populates="target")
|
||||
subscribes: list["Subscribe"] = Relationship(back_populates="target")
|
||||
time_weight: list["ScheduleTimeWeight"] = Relationship(back_populates="target")
|
||||
|
||||
|
||||
class ScheduleTimeWeight(Base):
|
||||
__tablename__ = "schedule_time_weight"
|
||||
class ScheduleTimeWeight(Model, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
target_id: Optional[int] = Field(
|
||||
default=None, foreign_key="nonebot_bison_target.id"
|
||||
)
|
||||
start_time: Optional[datetime.time]
|
||||
end_time: Optional[datetime.time]
|
||||
weight: Optional[int]
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
target_id = Column(Integer, ForeignKey(Target.id))
|
||||
start_time = Column(Time)
|
||||
end_time = Column(Time)
|
||||
weight = Column(Integer)
|
||||
target: Target = Relationship(back_populates="time_weight")
|
||||
|
||||
target = relationship("Target", back_populates="time_weight")
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class Subscribe(Base):
|
||||
__tablename__ = "subscribe"
|
||||
class Subscribe(Model, table=True):
|
||||
__table_args__ = (
|
||||
UniqueConstraint("target_id", "user_id", name="unique-subscribe-constraint"),
|
||||
)
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
target_id = Column(Integer, ForeignKey(Target.id))
|
||||
user_id = Column(Integer, ForeignKey(User.id))
|
||||
categories = Column(JSON)
|
||||
tags = Column(JSON)
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
target_id: Optional[int] = Field(
|
||||
default=None, foreign_key="nonebot_bison_target.id"
|
||||
)
|
||||
user_id: Optional[int] = Field(default=None, foreign_key="nonebot_bison_user.id")
|
||||
categories: list[Category] = Field(sa_column=Column(JSON))
|
||||
tags: list[Tag] = Field(sa_column=Column(JSON))
|
||||
|
||||
target = relationship("Target", back_populates="subscribes")
|
||||
user = relationship("User", back_populates="subscribes")
|
||||
target: Target = Relationship(back_populates="subscribes")
|
||||
user: User = Relationship(back_populates="subscribes")
|
||||
|
@ -0,0 +1,34 @@
|
||||
"""rename tables
|
||||
|
||||
Revision ID: 5da28f6facb3
|
||||
Revises: 5f3370328e44
|
||||
Create Date: 2023-01-15 19:04:54.987491
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5da28f6facb3"
|
||||
down_revision = "5f3370328e44"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.rename_table("target", "nonebot_bison_target")
|
||||
op.rename_table("user", "nonebot_bison_user")
|
||||
op.rename_table("schedule_time_weight", "nonebot_bison_scheduletimeweight")
|
||||
op.rename_table("subscribe", "nonebot_bison_subscribe")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.rename_table("nonebot_bison_subscribe", "subscribe")
|
||||
op.rename_table("nonebot_bison_scheduletimeweight", "schedule_time_weight")
|
||||
op.rename_table("nonebot_bison_user", "user")
|
||||
op.rename_table("nonebot_bison_target", "target")
|
||||
# ### end Alembic commands ###
|
@ -1,16 +1,14 @@
|
||||
import pytest
|
||||
from nonebug.app import App
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.sql.functions import func
|
||||
from sqlmodel.sql.expression import select
|
||||
|
||||
|
||||
async def test_add_subscribe(app: App, init_scheduler):
|
||||
|
||||
from nonebot_bison.config.db_config import config
|
||||
from nonebot_bison.config.db_model import Subscribe, Target, User
|
||||
from nonebot_bison.types import Target as TTarget
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import select
|
||||
|
||||
await config.add_subscribe(
|
||||
user=123,
|
||||
@ -74,7 +72,6 @@ async def test_add_subscribe(app: App, init_scheduler):
|
||||
|
||||
|
||||
async def test_add_dup_sub(init_scheduler):
|
||||
|
||||
from nonebot_bison.config.db_config import SubscribeDupException, config
|
||||
from nonebot_bison.types import Target as TTarget
|
||||
|
||||
@ -102,9 +99,12 @@ async def test_add_dup_sub(init_scheduler):
|
||||
|
||||
async def test_del_subsribe(init_scheduler):
|
||||
from nonebot_bison.config.db_config import config
|
||||
from nonebot_bison.config.db_model import Subscribe, Target, User
|
||||
from nonebot_bison.config.db_model import Subscribe, Target
|
||||
from nonebot_bison.types import Target as TTarget
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.sql.functions import func
|
||||
from sqlmodel.sql.expression import select
|
||||
|
||||
await config.add_subscribe(
|
||||
user=123,
|
||||
|
@ -1,7 +1,7 @@
|
||||
async def test_migration(use_legacy_config):
|
||||
from nonebot_bison.config.config_legacy import config as config_legacy
|
||||
from nonebot_bison.config.db import upgrade_db
|
||||
from nonebot_bison.config.db_config import config
|
||||
from nonebot_plugin_datastore.db import init_db
|
||||
|
||||
config_legacy.add_subscribe(
|
||||
user=123,
|
||||
@ -31,7 +31,7 @@ async def test_migration(use_legacy_config):
|
||||
tags=[],
|
||||
)
|
||||
# await data_migrate()
|
||||
await upgrade_db()
|
||||
await init_db()
|
||||
user123_config = await config.list_subscribe(123, "group")
|
||||
assert len(user123_config) == 2
|
||||
for c in user123_config:
|
||||
@ -55,8 +55,8 @@ async def test_migration(use_legacy_config):
|
||||
|
||||
async def test_migrate_dup(use_legacy_config):
|
||||
from nonebot_bison.config.config_legacy import config as config_legacy
|
||||
from nonebot_bison.config.db import upgrade_db
|
||||
from nonebot_bison.config.db_config import config
|
||||
from nonebot_plugin_datastore.db import init_db
|
||||
|
||||
config_legacy.add_subscribe(
|
||||
user=123,
|
||||
@ -77,6 +77,6 @@ async def test_migrate_dup(use_legacy_config):
|
||||
tags=[],
|
||||
)
|
||||
# await data_migrate()
|
||||
await upgrade_db()
|
||||
await init_db()
|
||||
user123_config = await config.list_subscribe(123, "group")
|
||||
assert len(user123_config) == 1
|
||||
|
@ -22,6 +22,7 @@ async def app(nonebug_init: None, tmp_path: Path, monkeypatch: pytest.MonkeyPatc
|
||||
config.superusers = {"10001"}
|
||||
config.log_level = "TRACE"
|
||||
config.bison_filter_log = False
|
||||
nonebot.require("nonebot_bison")
|
||||
return App(monkeypatch)
|
||||
|
||||
|
||||
@ -35,11 +36,10 @@ def dummy_user_subinfo(app: App):
|
||||
|
||||
@pytest.fixture
|
||||
async def db_migration(app: App):
|
||||
from nonebot_bison.config.db import upgrade_db
|
||||
from nonebot_bison.config.db_model import Subscribe, Target, User
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
from nonebot_plugin_datastore.db import get_engine, init_db
|
||||
|
||||
await upgrade_db()
|
||||
await init_db()
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
await sess.execute(delete(User))
|
||||
await sess.execute(delete(Subscribe))
|
||||
|
Loading…
x
Reference in New Issue
Block a user