⬆️ 适配 Pydantic V2 (#484)

* ⬆️ 适配 Pydantic V2

* 🐛 修复测试报错

* 🐛 适配忘记的 from_orm

* 🐛 忘记的 class-based `config`

* 🐛 更新 red 适配器版本
This commit is contained in:
uy/sun
2024-02-29 19:21:25 +08:00
committed by GitHub
parent ca68964ee9
commit b6e68730b7
17 changed files with 300 additions and 128 deletions
+6 -5
View File
@@ -3,6 +3,7 @@ from collections import defaultdict
from datetime import time, datetime
from collections.abc import Callable, Sequence, Awaitable
from nonebot.compat import model_dump
from sqlalchemy.orm import selectinload
from sqlalchemy.exc import IntegrityError
from sqlalchemy import func, delete, select
@@ -46,10 +47,10 @@ class DBConfig:
tags: list[Tag],
):
async with create_session() as session:
db_user_stmt = select(User).where(User.user_target == user.dict())
db_user_stmt = select(User).where(User.user_target == model_dump(user))
db_user: User | None = await session.scalar(db_user_stmt)
if not db_user:
db_user = User(user_target=user.dict())
db_user = User(user_target=model_dump(user))
session.add(db_user)
db_target_stmt = select(Target).where(Target.platform_name == platform_name).where(Target.target == target)
db_target: Target | None = await session.scalar(db_target_stmt)
@@ -76,7 +77,7 @@ class DBConfig:
async with create_session() as session:
query_stmt = (
select(Subscribe)
.where(User.user_target == user.dict())
.where(User.user_target == model_dump(user))
.join(User)
.options(selectinload(Subscribe.target))
)
@@ -95,7 +96,7 @@ class DBConfig:
async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str):
async with create_session() as session:
user_obj = await session.scalar(select(User).where(User.user_target == user.dict()))
user_obj = await session.scalar(select(User).where(User.user_target == model_dump(user)))
target_obj = await session.scalar(
select(Target).where(Target.platform_name == platform_name, Target.target == target)
)
@@ -121,7 +122,7 @@ class DBConfig:
subscribe_obj: Subscribe = await sess.scalar(
select(Subscribe)
.where(
User.user_target == user.dict(),
User.user_target == model_dump(user),
Target.target == target,
Target.platform_name == platform_name,
)
+2 -1
View File
@@ -1,4 +1,5 @@
from nonebot.log import logger
from nonebot.compat import model_dump
from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
@@ -21,7 +22,7 @@ async def data_migrate():
user_target = TargetQQGroup(group_id=user["user"])
else:
user_target = TargetQQPrivate(user_id=user["user"])
db_user = User(user_target=user_target.dict())
db_user = User(user_target=model_dump(user_target))
user_to_create.append(db_user)
user_sub_set = set()
for sub in user["subs"]:
+7 -2
View File
@@ -3,6 +3,7 @@ from pathlib import Path
from nonebot_plugin_saa import PlatformTarget
from sqlalchemy.dialects.postgresql import JSONB
from nonebot.compat import PYDANTIC_V2, ConfigDict
from nonebot_plugin_datastore import get_plugin_data
from sqlalchemy.orm import Mapped, relationship, mapped_column
from sqlalchemy import JSON, String, ForeignKey, UniqueConstraint
@@ -46,8 +47,12 @@ class ScheduleTimeWeight(Model):
target: Mapped[Target] = relationship(back_populates="time_weight")
class Config:
arbitrary_types_allowed = True
if PYDANTIC_V2:
model_config = ConfigDict(arbitrary_types_allowed=True)
else:
class Config:
arbitrary_types_allowed = True
class Subscribe(Model):
@@ -1,6 +1,7 @@
from abc import ABC
from pydantic import BaseModel
from nonebot.compat import PYDANTIC_V2, ConfigDict
from nonebot_plugin_saa.registries import AllSupportedPlatformTarget as UserInfo
from ....types import Tag, Category
@@ -10,8 +11,12 @@ class NBESFBase(BaseModel, ABC):
version: int # 表示nbesf格式版本,有效版本从1开始
groups: list = []
class Config:
orm_mode = True
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class Config:
orm_mode = True
class SubReceipt(BaseModel):
+29 -7
View File
@@ -6,6 +6,7 @@ from functools import partial
from nonebot.log import logger
from pydantic import BaseModel
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from nonebot.compat import PYDANTIC_V2, ConfigDict, model_dump, type_validate_json, type_validate_python
from ..utils import NBESFParseErr
from ....types import Tag, Category
@@ -16,14 +17,21 @@ from ...db_config import SubscribeDupException, config
NBESF_VERSION = 1
class UserHead(BaseModel, orm_mode=True):
class UserHead(BaseModel):
"""Bison快递包收货信息"""
type: str
uid: int
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class Target(BaseModel, orm_mode=True):
class Config:
orm_mode = True
class Target(BaseModel):
"""Bsion快递包发货信息"""
target_name: str
@@ -31,14 +39,28 @@ class Target(BaseModel, orm_mode=True):
platform_name: str
default_schedule_weight: int
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class SubPayload(BaseModel, orm_mode=True):
class Config:
orm_mode = True
class SubPayload(BaseModel):
"""Bison快递包里的单件货物"""
categories: list[Category]
tags: list[Tag]
target: Target
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class Config:
orm_mode = True
class SubPack(BaseModel):
"""Bison给指定用户派送的快递包"""
@@ -56,7 +78,7 @@ class SubGroup(
结构参见`nbesf_model`下的对应版本
"""
version = NBESF_VERSION
version: int = NBESF_VERSION
groups: list[SubPack]
@@ -84,7 +106,7 @@ async def subs_receipt_gen(nbesf_data: SubGroup):
tags=sub.tags,
)
try:
await config.add_subscribe(receipt.user, **receipt.dict(exclude={"user"}))
await config.add_subscribe(receipt.user, **model_dump(receipt, exclude={"user"}))
except SubscribeDupException:
logger.warning(f"!添加订阅条目 {repr(receipt)} 失败: 相同的订阅已存在")
except Exception as e:
@@ -96,9 +118,9 @@ async def subs_receipt_gen(nbesf_data: SubGroup):
def nbesf_parser(raw_data: Any) -> SubGroup:
try:
if isinstance(raw_data, str):
nbesf_data = SubGroup.parse_raw(raw_data)
nbesf_data = type_validate_json(SubGroup, raw_data)
else:
nbesf_data = SubGroup.parse_obj(raw_data)
nbesf_data = type_validate_python(SubGroup, raw_data)
except Exception as e:
logger.error("数据解析失败,该数据格式可能不满足NBESF格式标准!")
+20 -5
View File
@@ -6,6 +6,7 @@ from functools import partial
from nonebot.log import logger
from pydantic import BaseModel
from nonebot_plugin_saa.registries import AllSupportedPlatformTarget
from nonebot.compat import PYDANTIC_V2, ConfigDict, model_dump, type_validate_json, type_validate_python
from ..utils import NBESFParseErr
from ....types import Tag, Category
@@ -16,7 +17,7 @@ from ...db_config import SubscribeDupException, config
NBESF_VERSION = 2
class Target(BaseModel, orm_mode=True):
class Target(BaseModel):
"""Bsion快递包发货信息"""
target_name: str
@@ -24,14 +25,28 @@ class Target(BaseModel, orm_mode=True):
platform_name: str
default_schedule_weight: int
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class SubPayload(BaseModel, orm_mode=True):
class Config:
orm_mode = True
class SubPayload(BaseModel):
"""Bison快递包里的单件货物"""
categories: list[Category]
tags: list[Tag]
target: Target
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class Config:
orm_mode = True
class SubPack(BaseModel):
"""Bison给指定用户派送的快递包"""
@@ -68,7 +83,7 @@ async def subs_receipt_gen(nbesf_data: SubGroup):
tags=sub.tags,
)
try:
await config.add_subscribe(receipt.user, **receipt.dict(exclude={"user"}))
await config.add_subscribe(receipt.user, **model_dump(receipt, exclude={"user"}))
except SubscribeDupException:
logger.warning(f"!添加订阅条目 {repr(receipt)} 失败: 相同的订阅已存在")
except Exception as e:
@@ -80,9 +95,9 @@ async def subs_receipt_gen(nbesf_data: SubGroup):
def nbesf_parser(raw_data: Any) -> SubGroup:
try:
if isinstance(raw_data, str):
nbesf_data = SubGroup.parse_raw(raw_data)
nbesf_data = type_validate_json(SubGroup, raw_data)
else:
nbesf_data = SubGroup.parse_obj(raw_data)
nbesf_data = type_validate_python(SubGroup, raw_data)
except Exception as e:
logger.error("数据解析失败,该数据格式可能不满足NBESF格式标准!")
+2 -1
View File
@@ -6,6 +6,7 @@ from sqlalchemy import select
from nonebot.log import logger
from sqlalchemy.sql.selectable import Select
from nonebot_plugin_saa import PlatformTarget
from nonebot.compat import type_validate_python
from nonebot_plugin_datastore.db import create_session
from sqlalchemy.orm.strategy_options import selectinload
@@ -37,7 +38,7 @@ async def subscribes_export(selector: Callable[[Select], Select]) -> v2.SubGroup
user_id_sub_dict: dict[int, list[v2.SubPayload]] = defaultdict(list)
for sub in sub_data:
sub_paylaod = v2.SubPayload.from_orm(sub)
sub_paylaod = type_validate_python(v2.SubPayload, sub)
user_id_sub_dict[sub.user_id].append(sub_paylaod)
for user in user_data: