🎨 按ruff的检查调整程序代码

This commit is contained in:
Azide
2023-07-16 00:22:20 +08:00
committed by felinae98
parent f232ce4c3e
commit dba8f2a9cb
42 changed files with 414 additions and 757 deletions
+4 -2
View File
@@ -1,2 +1,4 @@
from .db_config import config
from .utils import NoSuchSubscribeException, NoSuchTargetException, NoSuchUserException
from .db_config import config as config
from .utils import NoSuchUserException as NoSuchUserException
from .utils import NoSuchTargetException as NoSuchTargetException
from .utils import NoSuchSubscribeException as NoSuchSubscribeException
+28 -45
View File
@@ -1,20 +1,19 @@
import json
import os
from collections import defaultdict
from datetime import datetime
import json
from os import path
from pathlib import Path
from typing import DefaultDict, Literal, Mapping, TypedDict
from datetime import datetime
from collections import defaultdict
from typing import Literal, TypedDict
import nonebot
from nonebot.log import logger
from tinydb import Query, TinyDB
from ..utils import Singleton
from ..types import User, Target
from ..platform import platform_manager
from ..plugin_config import plugin_config
from ..types import Target, User
from ..utils import Singleton
from .utils import NoSuchSubscribeException, NoSuchUserException
from .utils import NoSuchUserException, NoSuchSubscribeException
supported_target_type = platform_manager.keys()
@@ -89,17 +88,16 @@ class Config(metaclass=Singleton):
self.target_user_cat_cache = {}
self.target_user_tag_cache = {}
self.target_list = {}
self.next_index: DefaultDict[str, int] = defaultdict(lambda: 0)
self.next_index: defaultdict[str, int] = defaultdict(lambda: 0)
else:
self.available = False
def add_subscribe(
self, user, user_type, target, target_name, target_type, cats, tags
):
def add_subscribe(self, user, user_type, target, target_name, target_type, cats, tags):
user_query = Query()
query = (user_query.user == user) & (user_query.user_type == user_type)
if user_data := self.user_target.get(query):
# update
assert not isinstance(user_data, list)
subs: list = user_data.get("subs", [])
subs.append(
{
@@ -132,9 +130,8 @@ class Config(metaclass=Singleton):
def list_subscribe(self, user, user_type) -> list[SubscribeContent]:
query = Query()
if user_sub := self.user_target.get(
(query.user == user) & (query.user_type == user_type)
):
if user_sub := self.user_target.get((query.user == user) & (query.user_type == user_type)):
assert not isinstance(user_sub, list)
return user_sub["subs"]
return []
@@ -146,6 +143,7 @@ class Config(metaclass=Singleton):
query = (user_query.user == user) & (user_query.user_type == user_type)
if not (query_res := self.user_target.get(query)):
raise NoSuchUserException()
assert not isinstance(query_res, list)
subs = query_res.get("subs", [])
for idx, sub in enumerate(subs):
if sub.get("target") == target and sub.get("target_type") == target_type:
@@ -155,13 +153,12 @@ class Config(metaclass=Singleton):
return
raise NoSuchSubscribeException()
def update_subscribe(
self, user, user_type, target, target_name, target_type, cats, tags
):
def update_subscribe(self, user, user_type, target, target_name, target_type, cats, tags):
user_query = Query()
query = (user_query.user == user) & (user_query.user_type == user_type)
if user_data := self.user_target.get(query):
# update
assert not isinstance(user_data, list)
subs: list = user_data.get("subs", [])
find_flag = False
for item in subs:
@@ -182,19 +179,13 @@ class Config(metaclass=Singleton):
def update_send_cache(self):
res = {target_type: defaultdict(list) for target_type in supported_target_type}
cat_res = {
target_type: defaultdict(lambda: defaultdict(list))
for target_type in supported_target_type
}
tag_res = {
target_type: defaultdict(lambda: defaultdict(list))
for target_type in supported_target_type
}
cat_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
tag_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
# res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
to_del = []
for user in self.user_target.all():
for sub in user.get("subs", []):
if not sub.get("target_type") in supported_target_type:
if sub.get("target_type") not in supported_target_type:
to_del.append(
{
"user": user["user"],
@@ -204,36 +195,28 @@ class Config(metaclass=Singleton):
}
)
continue
res[sub["target_type"]][sub["target"]].append(
User(user["user"], user["user_type"])
)
cat_res[sub["target_type"]][sub["target"]][
"{}-{}".format(user["user_type"], user["user"])
] = sub["cats"]
tag_res[sub["target_type"]][sub["target"]][
"{}-{}".format(user["user_type"], user["user"])
] = sub["tags"]
res[sub["target_type"]][sub["target"]].append(User(user["user"], user["user_type"]))
cat_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[
"cats"
]
tag_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[
"tags"
]
self.target_user_cache = res
self.target_user_cat_cache = cat_res
self.target_user_tag_cache = tag_res
for target_type in self.target_user_cache:
self.target_list[target_type] = list(
self.target_user_cache[target_type].keys()
)
self.target_list[target_type] = list(self.target_user_cache[target_type].keys())
logger.info(f"Deleting {to_del}")
for d in to_del:
self.del_subscribe(**d)
def get_sub_category(self, target_type, target, user_type, user):
return self.target_user_cat_cache[target_type][target][
"{}-{}".format(user_type, user)
]
return self.target_user_cat_cache[target_type][target][f"{user_type}-{user}"]
def get_sub_tags(self, target_type, target, user_type, user):
return self.target_user_tag_cache[target_type][target][
"{}-{}".format(user_type, user)
]
return self.target_user_tag_cache[target_type][target][f"{user_type}-{user}"]
def get_next_target(self, target_type):
# FIXME 插入或删除target后对队列的影响(但是并不是大问题
+40 -89
View File
@@ -1,19 +1,19 @@
import asyncio
from collections import defaultdict
from datetime import datetime, time
from typing import Awaitable, Callable, Optional, Sequence
from datetime import time, datetime
from collections.abc import Callable, Sequence, Awaitable
from nonebot_plugin_datastore import create_session
from nonebot_plugin_saa import PlatformTarget
from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload
from sqlalchemy.exc import IntegrityError
from sqlalchemy import func, delete, select
from nonebot_plugin_saa import PlatformTarget
from nonebot_plugin_datastore import create_session
from ..types import Category, PlatformWeightConfigResp, Tag
from ..types import Tag
from ..types import Target as T_Target
from ..types import TimeWeightConfig, UserSubInfo, WeightConfig
from .db_model import ScheduleTimeWeight, Subscribe, Target, User
from .utils import NoSuchTargetException
from .db_model import User, Target, Subscribe, ScheduleTimeWeight
from ..types import Category, UserSubInfo, WeightConfig, TimeWeightConfig, PlatformWeightConfigResp
def _get_time():
@@ -48,23 +48,17 @@ class DBConfig:
):
async with create_session() as session:
db_user_stmt = select(User).where(User.user_target == user.dict())
db_user: Optional[User] = await session.scalar(db_user_stmt)
db_user: User | None = await session.scalar(db_user_stmt)
if not db_user:
db_user = User(user_target=user.dict())
session.add(db_user)
db_target_stmt = (
select(Target)
.where(Target.platform_name == platform_name)
.where(Target.target == target)
select(Target).where(Target.platform_name == platform_name).where(Target.target == target)
)
db_target: Optional[Target] = await session.scalar(db_target_stmt)
db_target: Target | None = await session.scalar(db_target_stmt)
if not db_target:
db_target = Target(
target=target, platform_name=platform_name, target_name=target_name
)
await asyncio.gather(
*[hook(platform_name, target) for hook in self.add_target_hook]
)
db_target = Target(target=target, platform_name=platform_name, target_name=target_name)
await asyncio.gather(*[hook(platform_name, target) for hook in self.add_target_hook])
else:
db_target.target_name = target_name
subscribe = Subscribe(
@@ -96,44 +90,25 @@ class DBConfig:
"""获取数据库中带有user、target信息的subscribe数据"""
async with create_session() as session:
query_stmt = (
select(Subscribe)
.join(User)
.options(selectinload(Subscribe.target), selectinload(Subscribe.user))
select(Subscribe).join(User).options(selectinload(Subscribe.target), selectinload(Subscribe.user))
)
subs = (await session.scalars(query_stmt)).all()
return subs
async def del_subscribe(
self, user: PlatformTarget, target: str, platform_name: str
):
async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str):
async with create_session() as session:
user_obj = await session.scalar(
select(User).where(User.user_target == user.dict())
)
user_obj = await session.scalar(select(User).where(User.user_target == user.dict()))
target_obj = await session.scalar(
select(Target).where(
Target.platform_name == platform_name, Target.target == target
)
)
await session.execute(
delete(Subscribe).where(
Subscribe.user == user_obj, Subscribe.target == target_obj
)
select(Target).where(Target.platform_name == platform_name, Target.target == target)
)
await session.execute(delete(Subscribe).where(Subscribe.user == user_obj, Subscribe.target == target_obj))
target_count = await session.scalar(
select(func.count())
.select_from(Subscribe)
.where(Subscribe.target == target_obj)
select(func.count()).select_from(Subscribe).where(Subscribe.target == target_obj)
)
if target_count == 0:
# delete empty target
await asyncio.gather(
*[
hook(platform_name, T_Target(target))
for hook in self.delete_target_hook
]
)
await asyncio.gather(*[hook(platform_name, T_Target(target)) for hook in self.delete_target_hook])
await session.commit()
async def update_subscribe(
@@ -165,29 +140,22 @@ class DBConfig:
async def get_platform_target(self, platform_name: str) -> Sequence[Target]:
async with create_session() as sess:
subq = select(Subscribe.target_id).distinct().subquery()
query = (
select(Target).join(subq).where(Target.platform_name == platform_name)
)
query = select(Target).join(subq).where(Target.platform_name == platform_name)
return (await sess.scalars(query)).all()
async def get_time_weight_config(
self, target: T_Target, platform_name: str
) -> WeightConfig:
async def get_time_weight_config(self, target: T_Target, platform_name: str) -> WeightConfig:
async with create_session() as sess:
time_weight_conf = (
await sess.scalars(
select(ScheduleTimeWeight)
.where(
Target.platform_name == platform_name, Target.target == target
)
.where(Target.platform_name == platform_name, Target.target == target)
.join(Target)
)
).all()
targetObj = await sess.scalar(
select(Target).where(
Target.platform_name == platform_name, Target.target == target
)
select(Target).where(Target.platform_name == platform_name, Target.target == target)
)
assert targetObj
return WeightConfig(
default=targetObj.default_schedule_weight,
time_config=[
@@ -200,22 +168,16 @@ class DBConfig:
],
)
async def update_time_weight_config(
self, target: T_Target, platform_name: str, conf: WeightConfig
):
async def update_time_weight_config(self, target: T_Target, platform_name: str, conf: WeightConfig):
async with create_session() as sess:
targetObj = await sess.scalar(
select(Target).where(
Target.platform_name == platform_name, Target.target == target
)
select(Target).where(Target.platform_name == platform_name, Target.target == target)
)
if not targetObj:
raise NoSuchTargetException()
target_id = targetObj.id
targetObj.default_schedule_weight = conf.default
delete_statement = delete(ScheduleTimeWeight).where(
ScheduleTimeWeight.target_id == target_id
)
delete_statement = delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id)
await sess.execute(delete_statement)
for time_conf in conf.time_config:
new_conf = ScheduleTimeWeight(
@@ -243,18 +205,13 @@ class DBConfig:
key = f"{target.platform_name}-{target.target}"
weight = target.default_schedule_weight
for time_conf in target.time_weight:
if (
time_conf.start_time <= cur_time
and time_conf.end_time > cur_time
):
if time_conf.start_time <= cur_time and time_conf.end_time > cur_time:
weight = time_conf.weight
break
res[key] = weight
return res
async def get_platform_target_subscribers(
self, platform_name: str, target: T_Target
) -> list[UserSubInfo]:
async def get_platform_target_subscribers(self, platform_name: str, target: T_Target) -> list[UserSubInfo]:
async with create_session() as sess:
query = (
select(Subscribe)
@@ -263,16 +220,14 @@ class DBConfig:
.options(selectinload(Subscribe.user))
)
subsribes = (await sess.scalars(query)).all()
return list(
map(
lambda subscribe: UserSubInfo(
PlatformTarget.deserialize(subscribe.user.user_target),
subscribe.categories,
subscribe.tags,
),
subsribes,
return [
UserSubInfo(
PlatformTarget.deserialize(subscribe.user.user_target),
subscribe.categories,
subscribe.tags,
)
)
for subscribe in subsribes
]
async def get_all_weight_config(
self,
@@ -281,9 +236,7 @@ class DBConfig:
async with create_session() as sess:
query = select(Target)
targets = (await sess.scalars(query)).all()
query = select(ScheduleTimeWeight).options(
selectinload(ScheduleTimeWeight.target)
)
query = select(ScheduleTimeWeight).options(selectinload(ScheduleTimeWeight.target))
time_weights = (await sess.scalars(query)).all()
for target in targets:
@@ -293,9 +246,7 @@ class DBConfig:
target=T_Target(target.target),
target_name=target.target_name,
platform_name=platform_name,
weight=WeightConfig(
default=target.default_schedule_weight, time_config=[]
),
weight=WeightConfig(default=target.default_schedule_weight, time_config=[]),
)
for time_weight_config in time_weights:
+5 -15
View File
@@ -1,22 +1,17 @@
from nonebot.log import logger
from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from sqlalchemy.ext.asyncio.session import AsyncSession
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from .db_model import User, Target, Subscribe
from .config_legacy import Config, ConfigContent, drop
from .db_model import Subscribe, Target, User
async def data_migrate():
config = Config()
if config.available:
logger.warning("You are still using legacy db, migrating to sqlite")
all_subs: list[ConfigContent] = list(
map(
lambda item: ConfigContent(**item),
config.get_all_subscribe().all(),
)
)
all_subs: list[ConfigContent] = [ConfigContent(**item) for item in config.get_all_subscribe().all()]
async with AsyncSession(get_engine()) as sess:
user_to_create = []
subscribe_to_create = []
@@ -37,8 +32,7 @@ async def data_migrate():
if key in user_sub_set:
# a user subscribe a target twice
logger.error(
f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次,"
"随机采用了一个订阅"
f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次,随机采用了一个订阅" # noqa: E501
)
continue
user_sub_set.add(key)
@@ -69,11 +63,7 @@ async def data_migrate():
tags=sub["tags"],
)
subscribe_to_create.append(subscribe_obj)
sess.add_all(
user_to_create
+ list(map(lambda x: x[0], platform_target_map.values()))
+ subscribe_to_create
)
sess.add_all(user_to_create + [x[0] for x in platform_target_map.values()] + subscribe_to_create)
await sess.commit()
drop()
logger.info("migrate success")
@@ -1,7 +1,7 @@
"""init db
Revision ID: 0571870f5222
Revises:
Revises:
Create Date: 2022-03-21 19:18:13.762626
"""
@@ -5,7 +5,6 @@ Revises: 5f3370328e44
Create Date: 2023-01-15 19:04:54.987491
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
@@ -5,7 +5,6 @@ Revises: 0571870f5222
Create Date: 2022-03-26 19:46:50.910721
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
@@ -18,14 +17,10 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("subscribe", schema=None) as batch_op:
batch_op.create_unique_constraint(
"unique-subscribe-constraint", ["target_id", "user_id"]
)
batch_op.create_unique_constraint("unique-subscribe-constraint", ["target_id", "user_id"])
with op.batch_alter_table("target", schema=None) as batch_op:
batch_op.create_unique_constraint(
"unique-target-constraint", ["target", "platform_name"]
)
batch_op.create_unique_constraint("unique-target-constraint", ["target", "platform_name"])
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"])
@@ -1,15 +1,14 @@
from abc import ABC
from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo
from pydantic import BaseModel
from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo
from ....types import Category, Tag
from ....types import Tag, Category
class NBESFBase(BaseModel, ABC):
version: int # 表示nbesf格式版本,有效版本从1开始
groups: list = list()
groups: list = []
class Config:
orm_mode = True
+10 -9
View File
@@ -1,25 +1,26 @@
from typing import cast
from collections import defaultdict
from typing import Callable, cast
from collections.abc import Callable
from nonebot.log import logger
from nonebot_plugin_datastore.db import create_session
from nonebot_plugin_saa import PlatformTarget
from sqlalchemy import select
from sqlalchemy.orm.strategy_options import selectinload
from nonebot.log import logger
from sqlalchemy.sql.selectable import Select
from nonebot_plugin_saa import PlatformTarget
from nonebot_plugin_datastore.db import create_session
from sqlalchemy.orm.strategy_options import selectinload
from ..db_model import Subscribe, User
from .nbesf_model import NBESFBase, v1, v2
from .utils import NBESFVerMatchErr
from ..db_model import User, Subscribe
from .nbesf_model import NBESFBase, v1, v2
async def subscribes_export(selector: Callable[[Select], Select]) -> v2.SubGroup:
"""
将Bison订阅导出为 Nonebot Bison Exchangable Subscribes File 标准格式的 SubGroup 类型数据
selector:
对 sqlalchemy Select 对象的操作函数,用于限定查询范围 e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group")
对 sqlalchemy Select 对象的操作函数,用于限定查询范围
e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group")
"""
async with create_session() as sess:
sub_stmt = select(Subscribe).join(User)