⬆️ 适配 Pydantic V2 (#484)

* ⬆️ 适配 Pydantic V2

* 🐛 修复测试报错

* 🐛 适配忘记的 from_orm

* 🐛 忘记的 class-based `config`

* 🐛 更新 red 适配器版本
This commit is contained in:
uy/sun
2024-02-29 19:21:25 +08:00
committed by GitHub
parent ca68964ee9
commit b6e68730b7
17 changed files with 300 additions and 128 deletions
+28
View File
@@ -0,0 +1,28 @@
from typing import Literal, overload
from pydantic import BaseModel
from nonebot.compat import PYDANTIC_V2
__all__ = ("model_validator", "model_rebuild")
if PYDANTIC_V2:
from pydantic import model_validator as model_validator
def model_rebuild(model: type[BaseModel]):
return model.model_rebuild()
else:
from pydantic import root_validator
@overload
def model_validator(*, mode: Literal["before"]): ...
@overload
def model_validator(*, mode: Literal["after"]): ...
def model_validator(*, mode: Literal["before", "after"]):
return root_validator(pre=mode == "before", allow_reuse=True)
def model_rebuild(model: type[BaseModel]):
return model.update_forward_refs()
+6 -5
View File
@@ -3,6 +3,7 @@ from collections import defaultdict
from datetime import time, datetime
from collections.abc import Callable, Sequence, Awaitable
from nonebot.compat import model_dump
from sqlalchemy.orm import selectinload
from sqlalchemy.exc import IntegrityError
from sqlalchemy import func, delete, select
@@ -46,10 +47,10 @@ class DBConfig:
tags: list[Tag],
):
async with create_session() as session:
db_user_stmt = select(User).where(User.user_target == user.dict())
db_user_stmt = select(User).where(User.user_target == model_dump(user))
db_user: User | None = await session.scalar(db_user_stmt)
if not db_user:
db_user = User(user_target=user.dict())
db_user = User(user_target=model_dump(user))
session.add(db_user)
db_target_stmt = select(Target).where(Target.platform_name == platform_name).where(Target.target == target)
db_target: Target | None = await session.scalar(db_target_stmt)
@@ -76,7 +77,7 @@ class DBConfig:
async with create_session() as session:
query_stmt = (
select(Subscribe)
.where(User.user_target == user.dict())
.where(User.user_target == model_dump(user))
.join(User)
.options(selectinload(Subscribe.target))
)
@@ -95,7 +96,7 @@ class DBConfig:
async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str):
async with create_session() as session:
user_obj = await session.scalar(select(User).where(User.user_target == user.dict()))
user_obj = await session.scalar(select(User).where(User.user_target == model_dump(user)))
target_obj = await session.scalar(
select(Target).where(Target.platform_name == platform_name, Target.target == target)
)
@@ -121,7 +122,7 @@ class DBConfig:
subscribe_obj: Subscribe = await sess.scalar(
select(Subscribe)
.where(
User.user_target == user.dict(),
User.user_target == model_dump(user),
Target.target == target,
Target.platform_name == platform_name,
)
+2 -1
View File
@@ -1,4 +1,5 @@
from nonebot.log import logger
from nonebot.compat import model_dump
from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
@@ -21,7 +22,7 @@ async def data_migrate():
user_target = TargetQQGroup(group_id=user["user"])
else:
user_target = TargetQQPrivate(user_id=user["user"])
db_user = User(user_target=user_target.dict())
db_user = User(user_target=model_dump(user_target))
user_to_create.append(db_user)
user_sub_set = set()
for sub in user["subs"]:
+7 -2
View File
@@ -3,6 +3,7 @@ from pathlib import Path
from nonebot_plugin_saa import PlatformTarget
from sqlalchemy.dialects.postgresql import JSONB
from nonebot.compat import PYDANTIC_V2, ConfigDict
from nonebot_plugin_datastore import get_plugin_data
from sqlalchemy.orm import Mapped, relationship, mapped_column
from sqlalchemy import JSON, String, ForeignKey, UniqueConstraint
@@ -46,8 +47,12 @@ class ScheduleTimeWeight(Model):
target: Mapped[Target] = relationship(back_populates="time_weight")
class Config:
arbitrary_types_allowed = True
if PYDANTIC_V2:
model_config = ConfigDict(arbitrary_types_allowed=True)
else:
class Config:
arbitrary_types_allowed = True
class Subscribe(Model):
@@ -1,6 +1,7 @@
from abc import ABC
from pydantic import BaseModel
from nonebot.compat import PYDANTIC_V2, ConfigDict
from nonebot_plugin_saa.registries import AllSupportedPlatformTarget as UserInfo
from ....types import Tag, Category
@@ -10,8 +11,12 @@ class NBESFBase(BaseModel, ABC):
version: int # 表示nbesf格式版本,有效版本从1开始
groups: list = []
class Config:
orm_mode = True
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class Config:
orm_mode = True
class SubReceipt(BaseModel):
+29 -7
View File
@@ -6,6 +6,7 @@ from functools import partial
from nonebot.log import logger
from pydantic import BaseModel
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from nonebot.compat import PYDANTIC_V2, ConfigDict, model_dump, type_validate_json, type_validate_python
from ..utils import NBESFParseErr
from ....types import Tag, Category
@@ -16,14 +17,21 @@ from ...db_config import SubscribeDupException, config
NBESF_VERSION = 1
class UserHead(BaseModel, orm_mode=True):
class UserHead(BaseModel):
"""Bison快递包收货信息"""
type: str
uid: int
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class Target(BaseModel, orm_mode=True):
class Config:
orm_mode = True
class Target(BaseModel):
"""Bsion快递包发货信息"""
target_name: str
@@ -31,14 +39,28 @@ class Target(BaseModel, orm_mode=True):
platform_name: str
default_schedule_weight: int
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class SubPayload(BaseModel, orm_mode=True):
class Config:
orm_mode = True
class SubPayload(BaseModel):
"""Bison快递包里的单件货物"""
categories: list[Category]
tags: list[Tag]
target: Target
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class Config:
orm_mode = True
class SubPack(BaseModel):
"""Bison给指定用户派送的快递包"""
@@ -56,7 +78,7 @@ class SubGroup(
结构参见`nbesf_model`下的对应版本
"""
version = NBESF_VERSION
version: int = NBESF_VERSION
groups: list[SubPack]
@@ -84,7 +106,7 @@ async def subs_receipt_gen(nbesf_data: SubGroup):
tags=sub.tags,
)
try:
await config.add_subscribe(receipt.user, **receipt.dict(exclude={"user"}))
await config.add_subscribe(receipt.user, **model_dump(receipt, exclude={"user"}))
except SubscribeDupException:
logger.warning(f"!添加订阅条目 {repr(receipt)} 失败: 相同的订阅已存在")
except Exception as e:
@@ -96,9 +118,9 @@ async def subs_receipt_gen(nbesf_data: SubGroup):
def nbesf_parser(raw_data: Any) -> SubGroup:
try:
if isinstance(raw_data, str):
nbesf_data = SubGroup.parse_raw(raw_data)
nbesf_data = type_validate_json(SubGroup, raw_data)
else:
nbesf_data = SubGroup.parse_obj(raw_data)
nbesf_data = type_validate_python(SubGroup, raw_data)
except Exception as e:
logger.error("数据解析失败,该数据格式可能不满足NBESF格式标准!")
+20 -5
View File
@@ -6,6 +6,7 @@ from functools import partial
from nonebot.log import logger
from pydantic import BaseModel
from nonebot_plugin_saa.registries import AllSupportedPlatformTarget
from nonebot.compat import PYDANTIC_V2, ConfigDict, model_dump, type_validate_json, type_validate_python
from ..utils import NBESFParseErr
from ....types import Tag, Category
@@ -16,7 +17,7 @@ from ...db_config import SubscribeDupException, config
NBESF_VERSION = 2
class Target(BaseModel, orm_mode=True):
class Target(BaseModel):
"""Bsion快递包发货信息"""
target_name: str
@@ -24,14 +25,28 @@ class Target(BaseModel, orm_mode=True):
platform_name: str
default_schedule_weight: int
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class SubPayload(BaseModel, orm_mode=True):
class Config:
orm_mode = True
class SubPayload(BaseModel):
"""Bison快递包里的单件货物"""
categories: list[Category]
tags: list[Tag]
target: Target
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class Config:
orm_mode = True
class SubPack(BaseModel):
"""Bison给指定用户派送的快递包"""
@@ -68,7 +83,7 @@ async def subs_receipt_gen(nbesf_data: SubGroup):
tags=sub.tags,
)
try:
await config.add_subscribe(receipt.user, **receipt.dict(exclude={"user"}))
await config.add_subscribe(receipt.user, **model_dump(receipt, exclude={"user"}))
except SubscribeDupException:
logger.warning(f"!添加订阅条目 {repr(receipt)} 失败: 相同的订阅已存在")
except Exception as e:
@@ -80,9 +95,9 @@ async def subs_receipt_gen(nbesf_data: SubGroup):
def nbesf_parser(raw_data: Any) -> SubGroup:
try:
if isinstance(raw_data, str):
nbesf_data = SubGroup.parse_raw(raw_data)
nbesf_data = type_validate_json(SubGroup, raw_data)
else:
nbesf_data = SubGroup.parse_obj(raw_data)
nbesf_data = type_validate_python(SubGroup, raw_data)
except Exception as e:
logger.error("数据解析失败,该数据格式可能不满足NBESF格式标准!")
+2 -1
View File
@@ -6,6 +6,7 @@ from sqlalchemy import select
from nonebot.log import logger
from sqlalchemy.sql.selectable import Select
from nonebot_plugin_saa import PlatformTarget
from nonebot.compat import type_validate_python
from nonebot_plugin_datastore.db import create_session
from sqlalchemy.orm.strategy_options import selectinload
@@ -37,7 +38,7 @@ async def subscribes_export(selector: Callable[[Select], Select]) -> v2.SubGroup
user_id_sub_dict: dict[int, list[v2.SubPayload]] = defaultdict(list)
for sub in sub_data:
sub_paylaod = v2.SubPayload.from_orm(sub)
sub_paylaod = type_validate_python(v2.SubPayload, sub)
user_id_sub_dict[sub.user_id].append(sub_paylaod)
for user in user_data:
+3 -5
View File
@@ -4,6 +4,7 @@ from functools import partial
from httpx import AsyncClient
from bs4 import BeautifulSoup as bs
from pydantic import Field, BaseModel
from nonebot.compat import type_validate_python
from ..post import Post
from ..types import Target, RawPost, Category
@@ -28,9 +29,6 @@ class BulletinListItem(BaseModel):
class BulletinList(BaseModel):
list: list[BulletinListItem]
class Config:
extra = "ignore"
class BulletinData(BaseModel):
cid: str
@@ -76,7 +74,7 @@ class Arknights(NewMessage):
async def get_sub_list(self, _) -> list[BulletinListItem]:
raw_data = await self.client.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS")
return ArkBulletinListResponse.parse_obj(raw_data.json()).data.list
return type_validate_python(ArkBulletinListResponse, raw_data.json()).data.list
def get_id(self, post: BulletinListItem) -> Any:
return post.cid
@@ -95,7 +93,7 @@ class Arknights(NewMessage):
raw_data = await self.client.get(
f"https://ak-webview.hypergryph.com/api/game/bulletin/{self.get_id(post=raw_post)}"
)
data = ArkBulletinResponse.parse_obj(raw_data.json()).data
data = type_validate_python(ArkBulletinResponse, raw_data.json()).data
def title_escape(text: str) -> str:
return text.replace("\\n", " - ")
+5 -2
View File
@@ -9,6 +9,9 @@ from datetime import datetime, timedelta
from httpx import AsyncClient
from nonebot.log import logger
from pydantic import Field, BaseModel
from nonebot.compat import type_validate_python
from nonebot_bison.compat import model_rebuild
from ..post import Post
from ..utils import SchedulerConfig, text_similarity
@@ -303,7 +306,7 @@ class Bilibililive(StatusChange):
infos = []
for target in targets:
if target in data.keys():
infos.append(self.Info.parse_obj(data[target]))
infos.append(type_validate_python(self.Info, data[target]))
else:
infos.append(self._gen_empty_info(int(target)))
return infos
@@ -428,4 +431,4 @@ class BilibiliBangumi(StatusChange):
)
Bilibililive.Info.update_forward_refs()
model_rebuild(Bilibililive.Info)
+5 -7
View File
@@ -1,12 +1,13 @@
import nonebot
from pydantic import Field, BaseSettings
from nonebot import get_plugin_config
from pydantic import Field, BaseModel
global_config = nonebot.get_driver().config
PlatformName = str
ThemeName = str
class PlugConfig(BaseSettings):
class PlugConfig(BaseModel):
bison_config_path: str = ""
bison_use_pic: bool = Field(
default=False,
@@ -22,7 +23,7 @@ class PlugConfig(BaseSettings):
bison_use_pic_merge: int = 0 # 多图片时启用图片合并转发(仅限群)
# 0:不启用;1:首条消息单独发送,剩余照片合并转发;2以及以上:所有消息全部合并转发
bison_resend_times: int = 0
bison_proxy: str | None
bison_proxy: str | None = None
bison_ua: str = Field(
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36",
description="默认UA",
@@ -34,8 +35,5 @@ class PlugConfig(BaseSettings):
def outer_url(self) -> str:
return self.bison_outer_url or f"http://localhost:{global_config.port}/bison/"
class Config:
extra = "ignore"
plugin_config = PlugConfig(**global_config.dict())
plugin_config = get_plugin_config(PlugConfig)
+3 -2
View File
@@ -8,6 +8,7 @@ from functools import wraps, partial
from collections.abc import Callable, Coroutine
from nonebot.log import logger
from nonebot.compat import model_dump
from ..scheduler.manager import init_scheduler
from ..config.subs_io.nbesf_model import v1, v2
@@ -95,14 +96,14 @@ async def subs_export(path: Path, format: str):
# If stream is None, it returns the produced stream.
# safe_dump produces only standard YAML tags and cannot represent an arbitrary Python object.
# 进行以下曲线救国方案
json_data = json.dumps(export_data.dict(), ensure_ascii=False)
json_data = json.dumps(model_dump(export_data), ensure_ascii=False)
yaml_data = pyyaml.safe_load(json_data)
pyyaml.safe_dump(yaml_data, f, sort_keys=False)
case "json":
logger.info("正在导出为json...")
json.dump(export_data.dict(), f, indent=4, ensure_ascii=False)
json.dump(model_dump(export_data), f, indent=4, ensure_ascii=False)
case _:
raise click.BadParameter(message=f"不支持的导出格式: {format}")
@@ -3,9 +3,10 @@ from datetime import datetime
from typing import TYPE_CHECKING, Literal
import jinja2
from pydantic import BaseModel, root_validator
from pydantic import BaseModel
from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
from nonebot_bison.compat import model_validator
from nonebot_bison.theme.utils import convert_to_qr
from nonebot_bison.theme import Theme, ThemeRenderError, ThemeRenderUnsupportError
@@ -35,7 +36,7 @@ class CeoboContent(BaseModel):
image: str | None
text: str | None
@root_validator
@model_validator(mode="before")
def check(cls, values):
if values["image"] is None and values["text"] is None:
raise ValueError("image and text cannot be both None")