mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2026-05-09 18:27:56 +08:00
⬆️ 适配 Pydantic V2 (#484)
* ⬆️ 适配 Pydantic V2 * 🐛 修复测试报错 * 🐛 适配忘记的 from_orm * 🐛 忘记的 class-based `config` * 🐛 更新 red 适配器版本
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
from typing import Literal, overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
from nonebot.compat import PYDANTIC_V2
|
||||
|
||||
__all__ = ("model_validator", "model_rebuild")
|
||||
|
||||
|
||||
if PYDANTIC_V2:
|
||||
from pydantic import model_validator as model_validator
|
||||
|
||||
def model_rebuild(model: type[BaseModel]):
|
||||
return model.model_rebuild()
|
||||
|
||||
else:
|
||||
from pydantic import root_validator
|
||||
|
||||
@overload
|
||||
def model_validator(*, mode: Literal["before"]): ...
|
||||
|
||||
@overload
|
||||
def model_validator(*, mode: Literal["after"]): ...
|
||||
|
||||
def model_validator(*, mode: Literal["before", "after"]):
|
||||
return root_validator(pre=mode == "before", allow_reuse=True)
|
||||
|
||||
def model_rebuild(model: type[BaseModel]):
|
||||
return model.update_forward_refs()
|
||||
@@ -3,6 +3,7 @@ from collections import defaultdict
|
||||
from datetime import time, datetime
|
||||
from collections.abc import Callable, Sequence, Awaitable
|
||||
|
||||
from nonebot.compat import model_dump
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy import func, delete, select
|
||||
@@ -46,10 +47,10 @@ class DBConfig:
|
||||
tags: list[Tag],
|
||||
):
|
||||
async with create_session() as session:
|
||||
db_user_stmt = select(User).where(User.user_target == user.dict())
|
||||
db_user_stmt = select(User).where(User.user_target == model_dump(user))
|
||||
db_user: User | None = await session.scalar(db_user_stmt)
|
||||
if not db_user:
|
||||
db_user = User(user_target=user.dict())
|
||||
db_user = User(user_target=model_dump(user))
|
||||
session.add(db_user)
|
||||
db_target_stmt = select(Target).where(Target.platform_name == platform_name).where(Target.target == target)
|
||||
db_target: Target | None = await session.scalar(db_target_stmt)
|
||||
@@ -76,7 +77,7 @@ class DBConfig:
|
||||
async with create_session() as session:
|
||||
query_stmt = (
|
||||
select(Subscribe)
|
||||
.where(User.user_target == user.dict())
|
||||
.where(User.user_target == model_dump(user))
|
||||
.join(User)
|
||||
.options(selectinload(Subscribe.target))
|
||||
)
|
||||
@@ -95,7 +96,7 @@ class DBConfig:
|
||||
|
||||
async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str):
|
||||
async with create_session() as session:
|
||||
user_obj = await session.scalar(select(User).where(User.user_target == user.dict()))
|
||||
user_obj = await session.scalar(select(User).where(User.user_target == model_dump(user)))
|
||||
target_obj = await session.scalar(
|
||||
select(Target).where(Target.platform_name == platform_name, Target.target == target)
|
||||
)
|
||||
@@ -121,7 +122,7 @@ class DBConfig:
|
||||
subscribe_obj: Subscribe = await sess.scalar(
|
||||
select(Subscribe)
|
||||
.where(
|
||||
User.user_target == user.dict(),
|
||||
User.user_target == model_dump(user),
|
||||
Target.target == target,
|
||||
Target.platform_name == platform_name,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from nonebot.log import logger
|
||||
from nonebot.compat import model_dump
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
|
||||
@@ -21,7 +22,7 @@ async def data_migrate():
|
||||
user_target = TargetQQGroup(group_id=user["user"])
|
||||
else:
|
||||
user_target = TargetQQPrivate(user_id=user["user"])
|
||||
db_user = User(user_target=user_target.dict())
|
||||
db_user = User(user_target=model_dump(user_target))
|
||||
user_to_create.append(db_user)
|
||||
user_sub_set = set()
|
||||
for sub in user["subs"]:
|
||||
|
||||
@@ -3,6 +3,7 @@ from pathlib import Path
|
||||
|
||||
from nonebot_plugin_saa import PlatformTarget
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from nonebot.compat import PYDANTIC_V2, ConfigDict
|
||||
from nonebot_plugin_datastore import get_plugin_data
|
||||
from sqlalchemy.orm import Mapped, relationship, mapped_column
|
||||
from sqlalchemy import JSON, String, ForeignKey, UniqueConstraint
|
||||
@@ -46,8 +47,12 @@ class ScheduleTimeWeight(Model):
|
||||
|
||||
target: Mapped[Target] = relationship(back_populates="time_weight")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
else:
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class Subscribe(Model):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC
|
||||
|
||||
from pydantic import BaseModel
|
||||
from nonebot.compat import PYDANTIC_V2, ConfigDict
|
||||
from nonebot_plugin_saa.registries import AllSupportedPlatformTarget as UserInfo
|
||||
|
||||
from ....types import Tag, Category
|
||||
@@ -10,8 +11,12 @@ class NBESFBase(BaseModel, ABC):
|
||||
version: int # 表示nbesf格式版本,有效版本从1开始
|
||||
groups: list = []
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
else:
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class SubReceipt(BaseModel):
|
||||
|
||||
@@ -6,6 +6,7 @@ from functools import partial
|
||||
from nonebot.log import logger
|
||||
from pydantic import BaseModel
|
||||
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
|
||||
from nonebot.compat import PYDANTIC_V2, ConfigDict, model_dump, type_validate_json, type_validate_python
|
||||
|
||||
from ..utils import NBESFParseErr
|
||||
from ....types import Tag, Category
|
||||
@@ -16,14 +17,21 @@ from ...db_config import SubscribeDupException, config
|
||||
NBESF_VERSION = 1
|
||||
|
||||
|
||||
class UserHead(BaseModel, orm_mode=True):
|
||||
class UserHead(BaseModel):
|
||||
"""Bison快递包收货信息"""
|
||||
|
||||
type: str
|
||||
uid: int
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
else:
|
||||
|
||||
class Target(BaseModel, orm_mode=True):
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class Target(BaseModel):
|
||||
"""Bsion快递包发货信息"""
|
||||
|
||||
target_name: str
|
||||
@@ -31,14 +39,28 @@ class Target(BaseModel, orm_mode=True):
|
||||
platform_name: str
|
||||
default_schedule_weight: int
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
else:
|
||||
|
||||
class SubPayload(BaseModel, orm_mode=True):
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class SubPayload(BaseModel):
|
||||
"""Bison快递包里的单件货物"""
|
||||
|
||||
categories: list[Category]
|
||||
tags: list[Tag]
|
||||
target: Target
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
else:
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class SubPack(BaseModel):
|
||||
"""Bison给指定用户派送的快递包"""
|
||||
@@ -56,7 +78,7 @@ class SubGroup(
|
||||
结构参见`nbesf_model`下的对应版本
|
||||
"""
|
||||
|
||||
version = NBESF_VERSION
|
||||
version: int = NBESF_VERSION
|
||||
groups: list[SubPack]
|
||||
|
||||
|
||||
@@ -84,7 +106,7 @@ async def subs_receipt_gen(nbesf_data: SubGroup):
|
||||
tags=sub.tags,
|
||||
)
|
||||
try:
|
||||
await config.add_subscribe(receipt.user, **receipt.dict(exclude={"user"}))
|
||||
await config.add_subscribe(receipt.user, **model_dump(receipt, exclude={"user"}))
|
||||
except SubscribeDupException:
|
||||
logger.warning(f"!添加订阅条目 {repr(receipt)} 失败: 相同的订阅已存在")
|
||||
except Exception as e:
|
||||
@@ -96,9 +118,9 @@ async def subs_receipt_gen(nbesf_data: SubGroup):
|
||||
def nbesf_parser(raw_data: Any) -> SubGroup:
|
||||
try:
|
||||
if isinstance(raw_data, str):
|
||||
nbesf_data = SubGroup.parse_raw(raw_data)
|
||||
nbesf_data = type_validate_json(SubGroup, raw_data)
|
||||
else:
|
||||
nbesf_data = SubGroup.parse_obj(raw_data)
|
||||
nbesf_data = type_validate_python(SubGroup, raw_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("数据解析失败,该数据格式可能不满足NBESF格式标准!")
|
||||
|
||||
@@ -6,6 +6,7 @@ from functools import partial
|
||||
from nonebot.log import logger
|
||||
from pydantic import BaseModel
|
||||
from nonebot_plugin_saa.registries import AllSupportedPlatformTarget
|
||||
from nonebot.compat import PYDANTIC_V2, ConfigDict, model_dump, type_validate_json, type_validate_python
|
||||
|
||||
from ..utils import NBESFParseErr
|
||||
from ....types import Tag, Category
|
||||
@@ -16,7 +17,7 @@ from ...db_config import SubscribeDupException, config
|
||||
NBESF_VERSION = 2
|
||||
|
||||
|
||||
class Target(BaseModel, orm_mode=True):
|
||||
class Target(BaseModel):
|
||||
"""Bsion快递包发货信息"""
|
||||
|
||||
target_name: str
|
||||
@@ -24,14 +25,28 @@ class Target(BaseModel, orm_mode=True):
|
||||
platform_name: str
|
||||
default_schedule_weight: int
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
else:
|
||||
|
||||
class SubPayload(BaseModel, orm_mode=True):
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class SubPayload(BaseModel):
|
||||
"""Bison快递包里的单件货物"""
|
||||
|
||||
categories: list[Category]
|
||||
tags: list[Tag]
|
||||
target: Target
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
else:
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class SubPack(BaseModel):
|
||||
"""Bison给指定用户派送的快递包"""
|
||||
@@ -68,7 +83,7 @@ async def subs_receipt_gen(nbesf_data: SubGroup):
|
||||
tags=sub.tags,
|
||||
)
|
||||
try:
|
||||
await config.add_subscribe(receipt.user, **receipt.dict(exclude={"user"}))
|
||||
await config.add_subscribe(receipt.user, **model_dump(receipt, exclude={"user"}))
|
||||
except SubscribeDupException:
|
||||
logger.warning(f"!添加订阅条目 {repr(receipt)} 失败: 相同的订阅已存在")
|
||||
except Exception as e:
|
||||
@@ -80,9 +95,9 @@ async def subs_receipt_gen(nbesf_data: SubGroup):
|
||||
def nbesf_parser(raw_data: Any) -> SubGroup:
|
||||
try:
|
||||
if isinstance(raw_data, str):
|
||||
nbesf_data = SubGroup.parse_raw(raw_data)
|
||||
nbesf_data = type_validate_json(SubGroup, raw_data)
|
||||
else:
|
||||
nbesf_data = SubGroup.parse_obj(raw_data)
|
||||
nbesf_data = type_validate_python(SubGroup, raw_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("数据解析失败,该数据格式可能不满足NBESF格式标准!")
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy import select
|
||||
from nonebot.log import logger
|
||||
from sqlalchemy.sql.selectable import Select
|
||||
from nonebot_plugin_saa import PlatformTarget
|
||||
from nonebot.compat import type_validate_python
|
||||
from nonebot_plugin_datastore.db import create_session
|
||||
from sqlalchemy.orm.strategy_options import selectinload
|
||||
|
||||
@@ -37,7 +38,7 @@ async def subscribes_export(selector: Callable[[Select], Select]) -> v2.SubGroup
|
||||
user_id_sub_dict: dict[int, list[v2.SubPayload]] = defaultdict(list)
|
||||
|
||||
for sub in sub_data:
|
||||
sub_paylaod = v2.SubPayload.from_orm(sub)
|
||||
sub_paylaod = type_validate_python(v2.SubPayload, sub)
|
||||
user_id_sub_dict[sub.user_id].append(sub_paylaod)
|
||||
|
||||
for user in user_data:
|
||||
|
||||
@@ -4,6 +4,7 @@ from functools import partial
|
||||
from httpx import AsyncClient
|
||||
from bs4 import BeautifulSoup as bs
|
||||
from pydantic import Field, BaseModel
|
||||
from nonebot.compat import type_validate_python
|
||||
|
||||
from ..post import Post
|
||||
from ..types import Target, RawPost, Category
|
||||
@@ -28,9 +29,6 @@ class BulletinListItem(BaseModel):
|
||||
class BulletinList(BaseModel):
|
||||
list: list[BulletinListItem]
|
||||
|
||||
class Config:
|
||||
extra = "ignore"
|
||||
|
||||
|
||||
class BulletinData(BaseModel):
|
||||
cid: str
|
||||
@@ -76,7 +74,7 @@ class Arknights(NewMessage):
|
||||
|
||||
async def get_sub_list(self, _) -> list[BulletinListItem]:
|
||||
raw_data = await self.client.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS")
|
||||
return ArkBulletinListResponse.parse_obj(raw_data.json()).data.list
|
||||
return type_validate_python(ArkBulletinListResponse, raw_data.json()).data.list
|
||||
|
||||
def get_id(self, post: BulletinListItem) -> Any:
|
||||
return post.cid
|
||||
@@ -95,7 +93,7 @@ class Arknights(NewMessage):
|
||||
raw_data = await self.client.get(
|
||||
f"https://ak-webview.hypergryph.com/api/game/bulletin/{self.get_id(post=raw_post)}"
|
||||
)
|
||||
data = ArkBulletinResponse.parse_obj(raw_data.json()).data
|
||||
data = type_validate_python(ArkBulletinResponse, raw_data.json()).data
|
||||
|
||||
def title_escape(text: str) -> str:
|
||||
return text.replace("\\n", " - ")
|
||||
|
||||
@@ -9,6 +9,9 @@ from datetime import datetime, timedelta
|
||||
from httpx import AsyncClient
|
||||
from nonebot.log import logger
|
||||
from pydantic import Field, BaseModel
|
||||
from nonebot.compat import type_validate_python
|
||||
|
||||
from nonebot_bison.compat import model_rebuild
|
||||
|
||||
from ..post import Post
|
||||
from ..utils import SchedulerConfig, text_similarity
|
||||
@@ -303,7 +306,7 @@ class Bilibililive(StatusChange):
|
||||
infos = []
|
||||
for target in targets:
|
||||
if target in data.keys():
|
||||
infos.append(self.Info.parse_obj(data[target]))
|
||||
infos.append(type_validate_python(self.Info, data[target]))
|
||||
else:
|
||||
infos.append(self._gen_empty_info(int(target)))
|
||||
return infos
|
||||
@@ -428,4 +431,4 @@ class BilibiliBangumi(StatusChange):
|
||||
)
|
||||
|
||||
|
||||
Bilibililive.Info.update_forward_refs()
|
||||
model_rebuild(Bilibililive.Info)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import nonebot
|
||||
from pydantic import Field, BaseSettings
|
||||
from nonebot import get_plugin_config
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
global_config = nonebot.get_driver().config
|
||||
PlatformName = str
|
||||
ThemeName = str
|
||||
|
||||
|
||||
class PlugConfig(BaseSettings):
|
||||
class PlugConfig(BaseModel):
|
||||
bison_config_path: str = ""
|
||||
bison_use_pic: bool = Field(
|
||||
default=False,
|
||||
@@ -22,7 +23,7 @@ class PlugConfig(BaseSettings):
|
||||
bison_use_pic_merge: int = 0 # 多图片时启用图片合并转发(仅限群)
|
||||
# 0:不启用;1:首条消息单独发送,剩余照片合并转发;2以及以上:所有消息全部合并转发
|
||||
bison_resend_times: int = 0
|
||||
bison_proxy: str | None
|
||||
bison_proxy: str | None = None
|
||||
bison_ua: str = Field(
|
||||
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36",
|
||||
description="默认UA",
|
||||
@@ -34,8 +35,5 @@ class PlugConfig(BaseSettings):
|
||||
def outer_url(self) -> str:
|
||||
return self.bison_outer_url or f"http://localhost:{global_config.port}/bison/"
|
||||
|
||||
class Config:
|
||||
extra = "ignore"
|
||||
|
||||
|
||||
plugin_config = PlugConfig(**global_config.dict())
|
||||
plugin_config = get_plugin_config(PlugConfig)
|
||||
|
||||
@@ -8,6 +8,7 @@ from functools import wraps, partial
|
||||
from collections.abc import Callable, Coroutine
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.compat import model_dump
|
||||
|
||||
from ..scheduler.manager import init_scheduler
|
||||
from ..config.subs_io.nbesf_model import v1, v2
|
||||
@@ -95,14 +96,14 @@ async def subs_export(path: Path, format: str):
|
||||
# If stream is None, it returns the produced stream.
|
||||
# safe_dump produces only standard YAML tags and cannot represent an arbitrary Python object.
|
||||
# 进行以下曲线救国方案
|
||||
json_data = json.dumps(export_data.dict(), ensure_ascii=False)
|
||||
json_data = json.dumps(model_dump(export_data), ensure_ascii=False)
|
||||
yaml_data = pyyaml.safe_load(json_data)
|
||||
pyyaml.safe_dump(yaml_data, f, sort_keys=False)
|
||||
|
||||
case "json":
|
||||
logger.info("正在导出为json...")
|
||||
|
||||
json.dump(export_data.dict(), f, indent=4, ensure_ascii=False)
|
||||
json.dump(model_dump(export_data), f, indent=4, ensure_ascii=False)
|
||||
|
||||
case _:
|
||||
raise click.BadParameter(message=f"不支持的导出格式: {format}")
|
||||
|
||||
@@ -3,9 +3,10 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import jinja2
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pydantic import BaseModel
|
||||
from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
|
||||
|
||||
from nonebot_bison.compat import model_validator
|
||||
from nonebot_bison.theme.utils import convert_to_qr
|
||||
from nonebot_bison.theme import Theme, ThemeRenderError, ThemeRenderUnsupportError
|
||||
|
||||
@@ -35,7 +36,7 @@ class CeoboContent(BaseModel):
|
||||
image: str | None
|
||||
text: str | None
|
||||
|
||||
@root_validator
|
||||
@model_validator(mode="before")
|
||||
def check(cls, values):
|
||||
if values["image"] is None and values["text"] is None:
|
||||
raise ValueError("image and text cannot be both None")
|
||||
|
||||
Reference in New Issue
Block a user