From dba8f2a9cb543a638d87e847a0e824261bae4b85 Mon Sep 17 00:00:00 2001
From: Azide
Date: Sun, 16 Jul 2023 00:22:20 +0800
Subject: [PATCH] =?UTF-8?q?:art:=20=E6=8C=89ruff=E7=9A=84=E6=A3=80?=
=?UTF-8?q?=E6=9F=A5=E8=B0=83=E6=95=B4=E7=A8=8B=E5=BA=8F=E4=BB=A3=E7=A0=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
nonebot_bison/__init__.py | 20 +--
nonebot_bison/admin_page/__init__.py | 30 ++--
nonebot_bison/admin_page/api.py | 83 +++++------
nonebot_bison/admin_page/jwt.py | 7 +-
nonebot_bison/admin_page/token_manager.py | 3 +-
nonebot_bison/config/__init__.py | 6 +-
nonebot_bison/config/config_legacy.py | 73 ++++------
nonebot_bison/config/db_config.py | 129 ++++++------------
nonebot_bison/config/db_migration.py | 20 +--
.../config/migrations/0571870f5222_init_db.py | 2 +-
.../migrations/5da28f6facb3_rename_tables.py | 1 -
.../migrations/c97c445e2bdb_add_constraint.py | 9 +-
.../config/subs_io/nbesf_model/base.py | 7 +-
nonebot_bison/config/subs_io/subs_io.py | 19 +--
nonebot_bison/platform/__init__.py | 11 +-
nonebot_bison/platform/arknights.py | 65 +++------
nonebot_bison/platform/bilibili.py | 62 +++------
nonebot_bison/platform/ff14.py | 9 +-
nonebot_bison/platform/mcbbsnews.py | 14 +-
nonebot_bison/platform/ncm.py | 33 ++---
nonebot_bison/platform/platform.py | 82 ++++-------
nonebot_bison/platform/rss.py | 11 +-
nonebot_bison/platform/weibo.py | 90 +++++-------
nonebot_bison/plugin_config.py | 11 +-
nonebot_bison/post/__init__.py | 2 +-
nonebot_bison/post/abstract_post.py | 17 +--
nonebot_bison/post/custom_post.py | 18 +--
nonebot_bison/post/post.py | 54 +++-----
nonebot_bison/scheduler/__init__.py | 4 +-
nonebot_bison/scheduler/manager.py | 20 ++-
nonebot_bison/scheduler/scheduler.py | 53 +++----
nonebot_bison/script/cli.py | 20 ++-
nonebot_bison/send.py | 7 +-
nonebot_bison/sub_manager/add_sub.py | 57 +++-----
nonebot_bison/sub_manager/del_sub.py | 30 +---
nonebot_bison/sub_manager/query_sub.py | 21 +--
nonebot_bison/sub_manager/utils.py | 18 ++-
nonebot_bison/utils/__init__.py | 25 ++--
nonebot_bison/utils/context.py | 11 +-
nonebot_bison/utils/http.py | 8 +-
nonebot_bison/utils/scheduler_config.py | 7 +-
pyproject.toml | 2 +-
42 files changed, 414 insertions(+), 757 deletions(-)
diff --git a/nonebot_bison/__init__.py b/nonebot_bison/__init__.py
index 3b7db3f..46dde3f 100644
--- a/nonebot_bison/__init__.py
+++ b/nonebot_bison/__init__.py
@@ -6,25 +6,18 @@ require("nonebot_plugin_saa")
import nonebot_plugin_saa
-from . import (
- admin_page,
- bootstrap,
- config,
- platform,
- post,
- scheduler,
- send,
- sub_manager,
- types,
- utils,
-)
from .plugin_config import PlugConfig, plugin_config
+from . import post, send, types, utils, config, platform, bootstrap, scheduler, admin_page, sub_manager
__help__version__ = "0.7.3"
nonebot_plugin_saa.enable_auto_select_bot()
__help__plugin__name__ = "nonebot_bison"
-__usage__ = f"本bot可以提供b站、微博等社交媒体的消息订阅,详情请查看本bot文档,或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,发送“查询订阅”或“删除订阅”管理订阅"
+__usage__ = (
+ "本bot可以提供b站、微博等社交媒体的消息订阅,详情请查看本bot文档,"
+ f"或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,"
+ f"发送“查询订阅”或“删除订阅”管理订阅"
+)
__supported_adapters__ = nonebot_plugin_saa.__plugin_meta__.supported_adapters
@@ -41,6 +34,7 @@ __plugin_meta__ = PluginMetadata(
__all__ = [
"admin_page",
+ "bootstrap",
"config",
"sub_manager",
"post",
diff --git a/nonebot_bison/admin_page/__init__.py b/nonebot_bison/admin_page/__init__.py
index b6ada13..661c57e 100644
--- a/nonebot_bison/admin_page/__init__.py
+++ b/nonebot_bison/admin_page/__init__.py
@@ -1,17 +1,16 @@
import os
from pathlib import Path
-from typing import Union
-from nonebot import get_driver, on_command
-from nonebot.adapters.onebot.v11 import Bot
-from nonebot.adapters.onebot.v11.event import PrivateMessageEvent
-from nonebot.drivers.fastapi import Driver
from nonebot.log import logger
from nonebot.rule import to_me
from nonebot.typing import T_State
+from nonebot import get_driver, on_command
+from nonebot.drivers.fastapi import Driver
+from nonebot.adapters.onebot.v11 import Bot
+from nonebot.adapters.onebot.v11.event import PrivateMessageEvent
-from ..plugin_config import plugin_config
from .api import router as api_router
+from ..plugin_config import plugin_config
from .token_manager import token_manager as tm
STATIC_PATH = (Path(__file__).parent / "dist").resolve()
@@ -28,11 +27,9 @@ def init_fastapi():
class SinglePageApplication(StaticFiles):
def __init__(self, directory: os.PathLike, index="index.html"):
self.index = index
- super().__init__(
- directory=directory, packages=None, html=True, check_dir=True
- )
+ super().__init__(directory=directory, packages=None, html=True, check_dir=True)
- def lookup_path(self, path: str) -> tuple[str, Union[os.stat_result, None]]:
+ def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
full_path, stat_res = super().lookup_path(path)
if stat_res is None:
return super().lookup_path(self.index)
@@ -45,9 +42,7 @@ def init_fastapi():
description="nonebot-bison webui and api",
)
nonebot_app.include_router(api_router)
- nonebot_app.mount(
- "/", SinglePageApplication(directory=static_path), name="bison-frontend"
- )
+ nonebot_app.mount("/", SinglePageApplication(directory=static_path), name="bison-frontend")
app = driver.server_app
app.mount("/bison", nonebot_app, "nonebot-bison")
@@ -63,10 +58,9 @@ def init_fastapi():
if host in ["0.0.0.0", "127.0.0.1"]:
host = "localhost"
logger.opt(colors=True).info(
- f"Nonebot Bison frontend will be running at: "
- f"http://{host}:{port}/bison"
+ f"Nonebot Bison frontend will be running at: " f"http://{host}:{port}/bison"
)
- logger.opt(colors=True).info(f"该页面不能被直接访问,请私聊bot 后台管理 以获取可访问地址")
+ logger.opt(colors=True).info("该页面不能被直接访问,请私聊bot 后台管理 以获取可访问地址")
def register_get_token_handler():
@@ -93,6 +87,4 @@ if (STATIC_PATH / "index.html").exists():
else:
logger.warning("your driver is not fastapi, webui feature will be disabled")
else:
- logger.warning(
- "Frontend file not found, please compile it or use docker or pypi version"
- )
+ logger.warning("Frontend file not found, please compile it or use docker or pypi version")
diff --git a/nonebot_bison/admin_page/api.py b/nonebot_bison/admin_page/api.py
index 5a4fb81..4c299af 100644
--- a/nonebot_bison/admin_page/api.py
+++ b/nonebot_bison/admin_page/api.py
@@ -1,35 +1,30 @@
import nonebot
from fastapi import status
-from fastapi.exceptions import HTTPException
-from fastapi.param_functions import Depends
from fastapi.routing import APIRouter
-from fastapi.security.oauth2 import OAuth2PasswordBearer
+from fastapi.param_functions import Depends
+from fastapi.exceptions import HTTPException
from nonebot_plugin_saa import TargetQQGroup
+from fastapi.security.oauth2 import OAuth2PasswordBearer
from nonebot_plugin_saa.utils.auto_select_bot import get_bot
-from ..apis import check_sub_target
-from ..config import (
- NoSuchSubscribeException,
- NoSuchTargetException,
- NoSuchUserException,
- config,
-)
-from ..config.db_config import SubscribeDupException
-from ..platform import platform_manager
-from ..types import Target as T_Target
from ..types import WeightConfig
-from ..utils.get_bot import get_groups
+from ..apis import check_sub_target
from .jwt import load_jwt, pack_jwt
+from ..types import Target as T_Target
+from ..utils.get_bot import get_groups
+from ..platform import platform_manager
from .token_manager import token_manager
+from ..config.db_config import SubscribeDupException
+from ..config import NoSuchUserException, NoSuchTargetException, NoSuchSubscribeException, config
from .types import (
- AddSubscribeReq,
+ TokenResp,
GlobalConf,
- PlatformConfig,
StatusResp,
+ SubscribeResp,
+ PlatformConfig,
+ AddSubscribeReq,
SubscribeConfig,
SubscribeGroupDetail,
- SubscribeResp,
- TokenResp,
)
router = APIRouter(prefix="/api", tags=["api"])
@@ -44,9 +39,7 @@ async def get_jwt_obj(token: str = Depends(oath_scheme)):
return obj
-async def check_group_permission(
- groupNumber: int, token_obj: dict = Depends(get_jwt_obj)
-):
+async def check_group_permission(groupNumber: int, token_obj: dict = Depends(get_jwt_obj)):
groups = token_obj["groups"]
for group in groups:
if int(groupNumber) == group["id"]:
@@ -95,15 +88,13 @@ async def auth(token: str) -> TokenResp:
jwt_obj = {
"id": qq,
"type": "admin",
- "groups": list(
- map(
- lambda info: {
- "id": info["group_id"],
- "name": info["group_name"],
- },
- await get_groups(),
- )
- ),
+ "groups": [
+ {
+ "id": info["group_id"],
+ "name": info["group_name"],
+ }
+ for info in await get_groups()
+ ],
}
ret_obj = TokenResp(
type="admin",
@@ -134,18 +125,16 @@ async def get_subs_info(jwt_obj: dict = Depends(get_jwt_obj)) -> SubscribeResp:
for group in groups:
group_id = group["id"]
raw_subs = await config.list_subscribe(TargetQQGroup(group_id=group_id))
- subs = list(
- map(
- lambda sub: SubscribeConfig(
- platformName=sub.target.platform_name,
- targetName=sub.target.target_name,
- cats=sub.categories,
- tags=sub.tags,
- target=sub.target.target,
- ),
- raw_subs,
+ subs = [
+ SubscribeConfig(
+ platformName=sub.target.platform_name,
+ targetName=sub.target.target_name,
+ cats=sub.categories,
+ tags=sub.tags,
+ target=sub.target.target,
)
- )
+ for sub in raw_subs
+ ]
res[group_id] = SubscribeGroupDetail(name=group["name"], subscribes=subs)
return res
@@ -174,9 +163,7 @@ async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp:
@router.delete("/subs", dependencies=[Depends(check_group_permission)])
async def del_group_sub(groupNumber: int, platformName: str, target: str):
try:
- await config.del_subscribe(
- TargetQQGroup(group_id=groupNumber), target, platformName
- )
+ await config.del_subscribe(TargetQQGroup(group_id=groupNumber), target, platformName)
except (NoSuchUserException, NoSuchSubscribeException):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe")
return StatusResp(ok=True, msg="")
@@ -204,13 +191,9 @@ async def get_weight_config():
@router.put("/weight", dependencies=[Depends(check_is_superuser)])
-async def update_weigth_config(
- platformName: str, target: str, weight_config: WeightConfig
-):
+async def update_weigth_config(platformName: str, target: str, weight_config: WeightConfig):
try:
- await config.update_time_weight_config(
- T_Target(target), platformName, weight_config
- )
+ await config.update_time_weight_config(T_Target(target), platformName, weight_config)
except NoSuchTargetException:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such subscribe")
return StatusResp(ok=True, msg="")
diff --git a/nonebot_bison/admin_page/jwt.py b/nonebot_bison/admin_page/jwt.py
index 661621a..866c184 100644
--- a/nonebot_bison/admin_page/jwt.py
+++ b/nonebot_bison/admin_page/jwt.py
@@ -1,7 +1,6 @@
-import datetime
import random
import string
-from typing import Optional
+import datetime
import jwt
@@ -16,8 +15,8 @@ def pack_jwt(obj: dict) -> str:
)
-def load_jwt(token: str) -> Optional[dict]:
+def load_jwt(token: str) -> dict | None:
try:
return jwt.decode(token, _key, algorithms=["HS256"])
- except:
+ except Exception:
return None
diff --git a/nonebot_bison/admin_page/token_manager.py b/nonebot_bison/admin_page/token_manager.py
index e540656..bb62d0a 100644
--- a/nonebot_bison/admin_page/token_manager.py
+++ b/nonebot_bison/admin_page/token_manager.py
@@ -1,6 +1,5 @@
import random
import string
-from typing import Optional
from expiringdict import ExpiringDict
@@ -9,7 +8,7 @@ class TokenManager:
def __init__(self):
self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60 * 10)
- def get_user(self, token: str) -> Optional[tuple]:
+ def get_user(self, token: str) -> tuple | None:
res = self.token_manager.get(token)
assert res is None or isinstance(res, tuple)
return res
diff --git a/nonebot_bison/config/__init__.py b/nonebot_bison/config/__init__.py
index 2fb9151..a04d41f 100644
--- a/nonebot_bison/config/__init__.py
+++ b/nonebot_bison/config/__init__.py
@@ -1,2 +1,4 @@
-from .db_config import config
-from .utils import NoSuchSubscribeException, NoSuchTargetException, NoSuchUserException
+from .db_config import config as config
+from .utils import NoSuchUserException as NoSuchUserException
+from .utils import NoSuchTargetException as NoSuchTargetException
+from .utils import NoSuchSubscribeException as NoSuchSubscribeException
diff --git a/nonebot_bison/config/config_legacy.py b/nonebot_bison/config/config_legacy.py
index d892b5c..24e7e4d 100644
--- a/nonebot_bison/config/config_legacy.py
+++ b/nonebot_bison/config/config_legacy.py
@@ -1,20 +1,19 @@
-import json
import os
-from collections import defaultdict
-from datetime import datetime
+import json
from os import path
from pathlib import Path
-from typing import DefaultDict, Literal, Mapping, TypedDict
+from datetime import datetime
+from collections import defaultdict
+from typing import Literal, TypedDict
-import nonebot
from nonebot.log import logger
from tinydb import Query, TinyDB
+from ..utils import Singleton
+from ..types import User, Target
from ..platform import platform_manager
from ..plugin_config import plugin_config
-from ..types import Target, User
-from ..utils import Singleton
-from .utils import NoSuchSubscribeException, NoSuchUserException
+from .utils import NoSuchUserException, NoSuchSubscribeException
supported_target_type = platform_manager.keys()
@@ -89,17 +88,16 @@ class Config(metaclass=Singleton):
self.target_user_cat_cache = {}
self.target_user_tag_cache = {}
self.target_list = {}
- self.next_index: DefaultDict[str, int] = defaultdict(lambda: 0)
+ self.next_index: defaultdict[str, int] = defaultdict(lambda: 0)
else:
self.available = False
- def add_subscribe(
- self, user, user_type, target, target_name, target_type, cats, tags
- ):
+ def add_subscribe(self, user, user_type, target, target_name, target_type, cats, tags):
user_query = Query()
query = (user_query.user == user) & (user_query.user_type == user_type)
if user_data := self.user_target.get(query):
# update
+ assert not isinstance(user_data, list)
subs: list = user_data.get("subs", [])
subs.append(
{
@@ -132,9 +130,8 @@ class Config(metaclass=Singleton):
def list_subscribe(self, user, user_type) -> list[SubscribeContent]:
query = Query()
- if user_sub := self.user_target.get(
- (query.user == user) & (query.user_type == user_type)
- ):
+ if user_sub := self.user_target.get((query.user == user) & (query.user_type == user_type)):
+ assert not isinstance(user_sub, list)
return user_sub["subs"]
return []
@@ -146,6 +143,7 @@ class Config(metaclass=Singleton):
query = (user_query.user == user) & (user_query.user_type == user_type)
if not (query_res := self.user_target.get(query)):
raise NoSuchUserException()
+ assert not isinstance(query_res, list)
subs = query_res.get("subs", [])
for idx, sub in enumerate(subs):
if sub.get("target") == target and sub.get("target_type") == target_type:
@@ -155,13 +153,12 @@ class Config(metaclass=Singleton):
return
raise NoSuchSubscribeException()
- def update_subscribe(
- self, user, user_type, target, target_name, target_type, cats, tags
- ):
+ def update_subscribe(self, user, user_type, target, target_name, target_type, cats, tags):
user_query = Query()
query = (user_query.user == user) & (user_query.user_type == user_type)
if user_data := self.user_target.get(query):
# update
+ assert not isinstance(user_data, list)
subs: list = user_data.get("subs", [])
find_flag = False
for item in subs:
@@ -182,19 +179,13 @@ class Config(metaclass=Singleton):
def update_send_cache(self):
res = {target_type: defaultdict(list) for target_type in supported_target_type}
- cat_res = {
- target_type: defaultdict(lambda: defaultdict(list))
- for target_type in supported_target_type
- }
- tag_res = {
- target_type: defaultdict(lambda: defaultdict(list))
- for target_type in supported_target_type
- }
+ cat_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
+ tag_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
# res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
to_del = []
for user in self.user_target.all():
for sub in user.get("subs", []):
- if not sub.get("target_type") in supported_target_type:
+ if sub.get("target_type") not in supported_target_type:
to_del.append(
{
"user": user["user"],
@@ -204,36 +195,28 @@ class Config(metaclass=Singleton):
}
)
continue
- res[sub["target_type"]][sub["target"]].append(
- User(user["user"], user["user_type"])
- )
- cat_res[sub["target_type"]][sub["target"]][
- "{}-{}".format(user["user_type"], user["user"])
- ] = sub["cats"]
- tag_res[sub["target_type"]][sub["target"]][
- "{}-{}".format(user["user_type"], user["user"])
- ] = sub["tags"]
+ res[sub["target_type"]][sub["target"]].append(User(user["user"], user["user_type"]))
+ cat_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[
+ "cats"
+ ]
+ tag_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[
+ "tags"
+ ]
self.target_user_cache = res
self.target_user_cat_cache = cat_res
self.target_user_tag_cache = tag_res
for target_type in self.target_user_cache:
- self.target_list[target_type] = list(
- self.target_user_cache[target_type].keys()
- )
+ self.target_list[target_type] = list(self.target_user_cache[target_type].keys())
logger.info(f"Deleting {to_del}")
for d in to_del:
self.del_subscribe(**d)
def get_sub_category(self, target_type, target, user_type, user):
- return self.target_user_cat_cache[target_type][target][
- "{}-{}".format(user_type, user)
- ]
+ return self.target_user_cat_cache[target_type][target][f"{user_type}-{user}"]
def get_sub_tags(self, target_type, target, user_type, user):
- return self.target_user_tag_cache[target_type][target][
- "{}-{}".format(user_type, user)
- ]
+ return self.target_user_tag_cache[target_type][target][f"{user_type}-{user}"]
def get_next_target(self, target_type):
# FIXME 插入或删除target后对队列的影响(但是并不是大问题
diff --git a/nonebot_bison/config/db_config.py b/nonebot_bison/config/db_config.py
index 38ef9af..ef6cc6b 100644
--- a/nonebot_bison/config/db_config.py
+++ b/nonebot_bison/config/db_config.py
@@ -1,19 +1,19 @@
import asyncio
from collections import defaultdict
-from datetime import datetime, time
-from typing import Awaitable, Callable, Optional, Sequence
+from datetime import time, datetime
+from collections.abc import Callable, Sequence, Awaitable
-from nonebot_plugin_datastore import create_session
-from nonebot_plugin_saa import PlatformTarget
-from sqlalchemy import delete, func, select
-from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy import func, delete, select
+from nonebot_plugin_saa import PlatformTarget
+from nonebot_plugin_datastore import create_session
-from ..types import Category, PlatformWeightConfigResp, Tag
+from ..types import Tag
from ..types import Target as T_Target
-from ..types import TimeWeightConfig, UserSubInfo, WeightConfig
-from .db_model import ScheduleTimeWeight, Subscribe, Target, User
from .utils import NoSuchTargetException
+from .db_model import User, Target, Subscribe, ScheduleTimeWeight
+from ..types import Category, UserSubInfo, WeightConfig, TimeWeightConfig, PlatformWeightConfigResp
def _get_time():
@@ -48,23 +48,17 @@ class DBConfig:
):
async with create_session() as session:
db_user_stmt = select(User).where(User.user_target == user.dict())
- db_user: Optional[User] = await session.scalar(db_user_stmt)
+ db_user: User | None = await session.scalar(db_user_stmt)
if not db_user:
db_user = User(user_target=user.dict())
session.add(db_user)
db_target_stmt = (
- select(Target)
- .where(Target.platform_name == platform_name)
- .where(Target.target == target)
+ select(Target).where(Target.platform_name == platform_name).where(Target.target == target)
)
- db_target: Optional[Target] = await session.scalar(db_target_stmt)
+ db_target: Target | None = await session.scalar(db_target_stmt)
if not db_target:
- db_target = Target(
- target=target, platform_name=platform_name, target_name=target_name
- )
- await asyncio.gather(
- *[hook(platform_name, target) for hook in self.add_target_hook]
- )
+ db_target = Target(target=target, platform_name=platform_name, target_name=target_name)
+ await asyncio.gather(*[hook(platform_name, target) for hook in self.add_target_hook])
else:
db_target.target_name = target_name
subscribe = Subscribe(
@@ -96,44 +90,25 @@ class DBConfig:
"""获取数据库中带有user、target信息的subscribe数据"""
async with create_session() as session:
query_stmt = (
- select(Subscribe)
- .join(User)
- .options(selectinload(Subscribe.target), selectinload(Subscribe.user))
+ select(Subscribe).join(User).options(selectinload(Subscribe.target), selectinload(Subscribe.user))
)
subs = (await session.scalars(query_stmt)).all()
return subs
- async def del_subscribe(
- self, user: PlatformTarget, target: str, platform_name: str
- ):
+ async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str):
async with create_session() as session:
- user_obj = await session.scalar(
- select(User).where(User.user_target == user.dict())
- )
+ user_obj = await session.scalar(select(User).where(User.user_target == user.dict()))
target_obj = await session.scalar(
- select(Target).where(
- Target.platform_name == platform_name, Target.target == target
- )
- )
- await session.execute(
- delete(Subscribe).where(
- Subscribe.user == user_obj, Subscribe.target == target_obj
- )
+ select(Target).where(Target.platform_name == platform_name, Target.target == target)
)
+ await session.execute(delete(Subscribe).where(Subscribe.user == user_obj, Subscribe.target == target_obj))
target_count = await session.scalar(
- select(func.count())
- .select_from(Subscribe)
- .where(Subscribe.target == target_obj)
+ select(func.count()).select_from(Subscribe).where(Subscribe.target == target_obj)
)
if target_count == 0:
# delete empty target
- await asyncio.gather(
- *[
- hook(platform_name, T_Target(target))
- for hook in self.delete_target_hook
- ]
- )
+ await asyncio.gather(*[hook(platform_name, T_Target(target)) for hook in self.delete_target_hook])
await session.commit()
async def update_subscribe(
@@ -165,29 +140,22 @@ class DBConfig:
async def get_platform_target(self, platform_name: str) -> Sequence[Target]:
async with create_session() as sess:
subq = select(Subscribe.target_id).distinct().subquery()
- query = (
- select(Target).join(subq).where(Target.platform_name == platform_name)
- )
+ query = select(Target).join(subq).where(Target.platform_name == platform_name)
return (await sess.scalars(query)).all()
- async def get_time_weight_config(
- self, target: T_Target, platform_name: str
- ) -> WeightConfig:
+ async def get_time_weight_config(self, target: T_Target, platform_name: str) -> WeightConfig:
async with create_session() as sess:
time_weight_conf = (
await sess.scalars(
select(ScheduleTimeWeight)
- .where(
- Target.platform_name == platform_name, Target.target == target
- )
+ .where(Target.platform_name == platform_name, Target.target == target)
.join(Target)
)
).all()
targetObj = await sess.scalar(
- select(Target).where(
- Target.platform_name == platform_name, Target.target == target
- )
+ select(Target).where(Target.platform_name == platform_name, Target.target == target)
)
+ assert targetObj
return WeightConfig(
default=targetObj.default_schedule_weight,
time_config=[
@@ -200,22 +168,16 @@ class DBConfig:
],
)
- async def update_time_weight_config(
- self, target: T_Target, platform_name: str, conf: WeightConfig
- ):
+ async def update_time_weight_config(self, target: T_Target, platform_name: str, conf: WeightConfig):
async with create_session() as sess:
targetObj = await sess.scalar(
- select(Target).where(
- Target.platform_name == platform_name, Target.target == target
- )
+ select(Target).where(Target.platform_name == platform_name, Target.target == target)
)
if not targetObj:
raise NoSuchTargetException()
target_id = targetObj.id
targetObj.default_schedule_weight = conf.default
- delete_statement = delete(ScheduleTimeWeight).where(
- ScheduleTimeWeight.target_id == target_id
- )
+ delete_statement = delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id)
await sess.execute(delete_statement)
for time_conf in conf.time_config:
new_conf = ScheduleTimeWeight(
@@ -243,18 +205,13 @@ class DBConfig:
key = f"{target.platform_name}-{target.target}"
weight = target.default_schedule_weight
for time_conf in target.time_weight:
- if (
- time_conf.start_time <= cur_time
- and time_conf.end_time > cur_time
- ):
+ if time_conf.start_time <= cur_time and time_conf.end_time > cur_time:
weight = time_conf.weight
break
res[key] = weight
return res
- async def get_platform_target_subscribers(
- self, platform_name: str, target: T_Target
- ) -> list[UserSubInfo]:
+ async def get_platform_target_subscribers(self, platform_name: str, target: T_Target) -> list[UserSubInfo]:
async with create_session() as sess:
query = (
select(Subscribe)
@@ -263,16 +220,14 @@ class DBConfig:
.options(selectinload(Subscribe.user))
)
subsribes = (await sess.scalars(query)).all()
- return list(
- map(
- lambda subscribe: UserSubInfo(
- PlatformTarget.deserialize(subscribe.user.user_target),
- subscribe.categories,
- subscribe.tags,
- ),
- subsribes,
+ return [
+ UserSubInfo(
+ PlatformTarget.deserialize(subscribe.user.user_target),
+ subscribe.categories,
+ subscribe.tags,
)
- )
+ for subscribe in subsribes
+ ]
async def get_all_weight_config(
self,
@@ -281,9 +236,7 @@ class DBConfig:
async with create_session() as sess:
query = select(Target)
targets = (await sess.scalars(query)).all()
- query = select(ScheduleTimeWeight).options(
- selectinload(ScheduleTimeWeight.target)
- )
+ query = select(ScheduleTimeWeight).options(selectinload(ScheduleTimeWeight.target))
time_weights = (await sess.scalars(query)).all()
for target in targets:
@@ -293,9 +246,7 @@ class DBConfig:
target=T_Target(target.target),
target_name=target.target_name,
platform_name=platform_name,
- weight=WeightConfig(
- default=target.default_schedule_weight, time_config=[]
- ),
+ weight=WeightConfig(default=target.default_schedule_weight, time_config=[]),
)
for time_weight_config in time_weights:
diff --git a/nonebot_bison/config/db_migration.py b/nonebot_bison/config/db_migration.py
index e20b1dc..08d3117 100644
--- a/nonebot_bison/config/db_migration.py
+++ b/nonebot_bison/config/db_migration.py
@@ -1,22 +1,17 @@
from nonebot.log import logger
from nonebot_plugin_datastore.db import get_engine
-from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from sqlalchemy.ext.asyncio.session import AsyncSession
+from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
+from .db_model import User, Target, Subscribe
from .config_legacy import Config, ConfigContent, drop
-from .db_model import Subscribe, Target, User
async def data_migrate():
config = Config()
if config.available:
logger.warning("You are still using legacy db, migrating to sqlite")
- all_subs: list[ConfigContent] = list(
- map(
- lambda item: ConfigContent(**item),
- config.get_all_subscribe().all(),
- )
- )
+ all_subs: list[ConfigContent] = [ConfigContent(**item) for item in config.get_all_subscribe().all()]
async with AsyncSession(get_engine()) as sess:
user_to_create = []
subscribe_to_create = []
@@ -37,8 +32,7 @@ async def data_migrate():
if key in user_sub_set:
# a user subscribe a target twice
logger.error(
- f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次,"
- "随机采用了一个订阅"
+ f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次,随机采用了一个订阅" # noqa: E501
)
continue
user_sub_set.add(key)
@@ -69,11 +63,7 @@ async def data_migrate():
tags=sub["tags"],
)
subscribe_to_create.append(subscribe_obj)
- sess.add_all(
- user_to_create
- + list(map(lambda x: x[0], platform_target_map.values()))
- + subscribe_to_create
- )
+ sess.add_all(user_to_create + [x[0] for x in platform_target_map.values()] + subscribe_to_create)
await sess.commit()
drop()
logger.info("migrate success")
diff --git a/nonebot_bison/config/migrations/0571870f5222_init_db.py b/nonebot_bison/config/migrations/0571870f5222_init_db.py
index d6e0c2c..347212a 100644
--- a/nonebot_bison/config/migrations/0571870f5222_init_db.py
+++ b/nonebot_bison/config/migrations/0571870f5222_init_db.py
@@ -1,7 +1,7 @@
"""init db
Revision ID: 0571870f5222
-Revises:
+Revises:
Create Date: 2022-03-21 19:18:13.762626
"""
diff --git a/nonebot_bison/config/migrations/5da28f6facb3_rename_tables.py b/nonebot_bison/config/migrations/5da28f6facb3_rename_tables.py
index c8eb5c7..20e7544 100644
--- a/nonebot_bison/config/migrations/5da28f6facb3_rename_tables.py
+++ b/nonebot_bison/config/migrations/5da28f6facb3_rename_tables.py
@@ -5,7 +5,6 @@ Revises: 5f3370328e44
Create Date: 2023-01-15 19:04:54.987491
"""
-import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
diff --git a/nonebot_bison/config/migrations/c97c445e2bdb_add_constraint.py b/nonebot_bison/config/migrations/c97c445e2bdb_add_constraint.py
index 9119d3b..807699e 100644
--- a/nonebot_bison/config/migrations/c97c445e2bdb_add_constraint.py
+++ b/nonebot_bison/config/migrations/c97c445e2bdb_add_constraint.py
@@ -5,7 +5,6 @@ Revises: 0571870f5222
Create Date: 2022-03-26 19:46:50.910721
"""
-import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
@@ -18,14 +17,10 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("subscribe", schema=None) as batch_op:
- batch_op.create_unique_constraint(
- "unique-subscribe-constraint", ["target_id", "user_id"]
- )
+ batch_op.create_unique_constraint("unique-subscribe-constraint", ["target_id", "user_id"])
with op.batch_alter_table("target", schema=None) as batch_op:
- batch_op.create_unique_constraint(
- "unique-target-constraint", ["target", "platform_name"]
- )
+ batch_op.create_unique_constraint("unique-target-constraint", ["target", "platform_name"])
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"])
diff --git a/nonebot_bison/config/subs_io/nbesf_model/base.py b/nonebot_bison/config/subs_io/nbesf_model/base.py
index 11ae2bb..f8e4b55 100644
--- a/nonebot_bison/config/subs_io/nbesf_model/base.py
+++ b/nonebot_bison/config/subs_io/nbesf_model/base.py
@@ -1,15 +1,14 @@
from abc import ABC
-from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo
from pydantic import BaseModel
+from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo
-from ....types import Category, Tag
+from ....types import Tag, Category
class NBESFBase(BaseModel, ABC):
-
version: int # 表示nbesf格式版本,有效版本从1开始
- groups: list = list()
+ groups: list = []
class Config:
orm_mode = True
diff --git a/nonebot_bison/config/subs_io/subs_io.py b/nonebot_bison/config/subs_io/subs_io.py
index 9a16472..21c1310 100644
--- a/nonebot_bison/config/subs_io/subs_io.py
+++ b/nonebot_bison/config/subs_io/subs_io.py
@@ -1,25 +1,26 @@
+from typing import cast
from collections import defaultdict
-from typing import Callable, cast
+from collections.abc import Callable
-from nonebot.log import logger
-from nonebot_plugin_datastore.db import create_session
-from nonebot_plugin_saa import PlatformTarget
from sqlalchemy import select
-from sqlalchemy.orm.strategy_options import selectinload
+from nonebot.log import logger
from sqlalchemy.sql.selectable import Select
+from nonebot_plugin_saa import PlatformTarget
+from nonebot_plugin_datastore.db import create_session
+from sqlalchemy.orm.strategy_options import selectinload
-from ..db_model import Subscribe, User
-from .nbesf_model import NBESFBase, v1, v2
from .utils import NBESFVerMatchErr
+from ..db_model import User, Subscribe
+from .nbesf_model import NBESFBase, v1, v2
async def subscribes_export(selector: Callable[[Select], Select]) -> v2.SubGroup:
-
"""
将Bison订阅导出为 Nonebot Bison Exchangable Subscribes File 标准格式的 SubGroup 类型数据
selector:
- 对 sqlalchemy Select 对象的操作函数,用于限定查询范围 e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group")
+ 对 sqlalchemy Select 对象的操作函数,用于限定查询范围
+ e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group")
"""
async with create_session() as sess:
sub_stmt = select(Subscribe).join(User)
diff --git a/nonebot_bison/platform/__init__.py b/nonebot_bison/platform/__init__.py
index e8d7186..c99ce12 100644
--- a/nonebot_bison/platform/__init__.py
+++ b/nonebot_bison/platform/__init__.py
@@ -1,23 +1,22 @@
-from collections import defaultdict
-from importlib import import_module
from pathlib import Path
from pkgutil import iter_modules
-from typing import DefaultDict, Type
+from collections import defaultdict
+from importlib import import_module
from .platform import Platform, make_no_target_group
_package_dir = str(Path(__file__).resolve().parent)
-for (_, module_name, _) in iter_modules([_package_dir]):
+for _, module_name, _ in iter_modules([_package_dir]):
import_module(f"{__name__}.{module_name}")
-_platform_list: DefaultDict[str, list[Type[Platform]]] = defaultdict(list)
+_platform_list: defaultdict[str, list[type[Platform]]] = defaultdict(list)
for _platform in Platform.registry:
if not _platform.enabled:
continue
_platform_list[_platform.platform_name].append(_platform)
-platform_manager: dict[str, Type[Platform]] = dict()
+platform_manager: dict[str, type[Platform]] = {}
for name, platform_list in _platform_list.items():
if len(platform_list) == 1:
platform_manager[name] = platform_list[0]
diff --git a/nonebot_bison/platform/arknights.py b/nonebot_bison/platform/arknights.py
index 2513de5..00f3fbd 100644
--- a/nonebot_bison/platform/arknights.py
+++ b/nonebot_bison/platform/arknights.py
@@ -1,25 +1,23 @@
import json
-from typing import Any, Optional
+from typing import Any
-from bs4 import BeautifulSoup as bs
from httpx import AsyncClient
from nonebot.plugin import require
+from bs4 import BeautifulSoup as bs
from ..post import Post
-from ..types import Category, RawPost, Target
+from ..types import Target, RawPost, Category
from ..utils.scheduler_config import SchedulerConfig
-from .platform import CategoryNotRecognize, NewMessage, StatusChange
+from .platform import NewMessage, StatusChange, CategoryNotRecognize
class ArknightsSchedConf(SchedulerConfig):
-
name = "arknights"
schedule_type = "interval"
schedule_setting = {"seconds": 30}
class Arknights(NewMessage):
-
categories = {1: "游戏公告"}
platform_name = "arknights"
name = "明日方舟游戏信息"
@@ -30,9 +28,7 @@ class Arknights(NewMessage):
has_target = False
@classmethod
- async def get_target_name(
- cls, client: AsyncClient, target: Target
- ) -> Optional[str]:
+ async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]:
@@ -92,7 +88,6 @@ class Arknights(NewMessage):
class AkVersion(StatusChange):
-
categories = {2: "更新信息"}
platform_name = "arknights"
name = "明日方舟游戏信息"
@@ -103,15 +98,11 @@ class AkVersion(StatusChange):
has_target = False
@classmethod
- async def get_target_name(
- cls, client: AsyncClient, target: Target
- ) -> Optional[str]:
+ async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
return "明日方舟游戏信息"
async def get_status(self, _):
- res_ver = await self.client.get(
- "https://ak-conf.hypergryph.com/config/prod/official/IOS/version"
- )
+ res_ver = await self.client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version")
res_preanounce = await self.client.get(
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json"
)
@@ -121,20 +112,10 @@ class AkVersion(StatusChange):
def compare_status(self, _, old_status, new_status):
res = []
- if (
- old_status.get("preAnnounceType") == 2
- and new_status.get("preAnnounceType") == 0
- ):
- res.append(
- Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息")
- )
- elif (
- old_status.get("preAnnounceType") == 0
- and new_status.get("preAnnounceType") == 2
- ):
- res.append(
- Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息")
- )
+ if old_status.get("preAnnounceType") == 2 and new_status.get("preAnnounceType") == 0:
+ res.append(Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息")) # noqa: E501
+ elif old_status.get("preAnnounceType") == 0 and new_status.get("preAnnounceType") == 2:
+ res.append(Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息")) # noqa: E501
if old_status.get("clientVersion") != new_status.get("clientVersion"):
res.append(Post("arknights", text="游戏本体更新(大更新)", target_name="明日方舟更新信息"))
if old_status.get("resVersion") != new_status.get("resVersion"):
@@ -149,7 +130,6 @@ class AkVersion(StatusChange):
class MonsterSiren(NewMessage):
-
categories = {3: "塞壬唱片新闻"}
platform_name = "arknights"
name = "明日方舟游戏信息"
@@ -160,15 +140,11 @@ class MonsterSiren(NewMessage):
has_target = False
@classmethod
- async def get_target_name(
- cls, client: AsyncClient, target: Target
- ) -> Optional[str]:
+ async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]:
- raw_data = await self.client.get(
- "https://monster-siren.hypergryph.com/api/news"
- )
+ raw_data = await self.client.get("https://monster-siren.hypergryph.com/api/news")
return raw_data.json()["data"]["list"]
def get_id(self, post: RawPost) -> Any:
@@ -182,14 +158,12 @@ class MonsterSiren(NewMessage):
async def parse(self, raw_post: RawPost) -> Post:
url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}'
- res = await self.client.get(
- f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}'
- )
+ res = await self.client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}')
raw_data = res.json()
content = raw_data["data"]["content"]
content = content.replace("
", "\n")
soup = bs(content, "html.parser")
- imgs = list(map(lambda x: x["src"], soup("img")))
+ imgs = [x["src"] for x in soup("img")]
text = f'{raw_post["title"]}\n{soup.text.strip()}'
return Post(
"monster-siren",
@@ -203,7 +177,6 @@ class MonsterSiren(NewMessage):
class TerraHistoricusComic(NewMessage):
-
categories = {4: "泰拉记事社漫画"}
platform_name = "arknights"
name = "明日方舟游戏信息"
@@ -214,15 +187,11 @@ class TerraHistoricusComic(NewMessage):
has_target = False
@classmethod
- async def get_target_name(
- cls, client: AsyncClient, target: Target
- ) -> Optional[str]:
+ async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]:
- raw_data = await self.client.get(
- "https://terra-historicus.hypergryph.com/api/recentUpdate"
- )
+ raw_data = await self.client.get("https://terra-historicus.hypergryph.com/api/recentUpdate")
return raw_data.json()["data"]
def get_id(self, post: RawPost) -> Any:
diff --git a/nonebot_bison/platform/bilibili.py b/nonebot_bison/platform/bilibili.py
index 9769348..c81b0f1 100644
--- a/nonebot_bison/platform/bilibili.py
+++ b/nonebot_bison/platform/bilibili.py
@@ -1,14 +1,14 @@
-import json
import re
+import json
+from typing import Any
from copy import deepcopy
-from datetime import datetime, timedelta
from enum import Enum, unique
-from typing import Any, Literal, Optional
+from typing_extensions import Self
+from datetime import datetime, timedelta
from httpx import AsyncClient
from nonebot.log import logger
-from pydantic import BaseModel, Field
-from typing_extensions import Self
+from pydantic import Field, BaseModel
from ..post import Post
from ..types import ApiError, Category, RawPost, Tag, Target
@@ -25,9 +25,7 @@ class BilibiliSchedConf(SchedulerConfig):
cookie_expire_time = timedelta(hours=5)
def __init__(self):
- self._client_refresh_time = datetime(
- year=2000, month=1, day=1
- ) # an expired time
+ self._client_refresh_time = datetime(year=2000, month=1, day=1) # an expired time
super().__init__()
async def _init_session(self):
@@ -69,12 +67,8 @@ class Bilibili(NewMessage):
parse_target_promot = "请输入用户主页的链接"
@classmethod
- async def get_target_name(
- cls, client: AsyncClient, target: Target
- ) -> Optional[str]:
- res = await client.get(
- "https://api.bilibili.com/x/web-interface/card", params={"mid": target}
- )
+ async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
+ res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target})
res.raise_for_status()
res_data = res.json()
if res_data["code"]:
@@ -129,12 +123,7 @@ class Bilibili(NewMessage):
return self._do_get_category(post_type)
def get_tags(self, raw_post: RawPost) -> list[Tag]:
- return [
- *map(
- lambda tp: tp["topic_name"],
- raw_post["display"]["topic_info"]["topic_details"],
- )
- ]
+ return [*(tp["topic_name"] for tp in raw_post["display"]["topic_info"]["topic_details"])]
def _get_info(self, post_type: Category, card) -> tuple[str, list]:
if post_type == 1:
@@ -178,24 +167,16 @@ class Bilibili(NewMessage):
url = ""
if post_type == 1:
# 一般动态
- url = "https://t.bilibili.com/{}".format(
- raw_post["desc"]["dynamic_id_str"]
- )
+ url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"])
elif post_type == 2:
# 专栏文章
- url = "https://www.bilibili.com/read/cv{}".format(
- raw_post["desc"]["rid"]
- )
+ url = "https://www.bilibili.com/read/cv{}".format(raw_post["desc"]["rid"])
elif post_type == 3:
# 视频
- url = "https://www.bilibili.com/video/{}".format(
- raw_post["desc"]["bvid"]
- )
+ url = "https://www.bilibili.com/video/{}".format(raw_post["desc"]["bvid"])
elif post_type == 4:
# 纯文字
- url = "https://t.bilibili.com/{}".format(
- raw_post["desc"]["dynamic_id_str"]
- )
+ url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"])
text, pic = self._get_info(post_type, card_content)
elif post_type == 5:
# 转发
@@ -261,10 +242,7 @@ class Bilibililive(StatusChange):
def get_live_action(self, old_info: Self) -> "Bilibililive.LiveAction":
status = Bilibililive.LiveStatus
action = Bilibililive.LiveAction
- if (
- old_info.live_status in [status.OFF, status.CYCLE]
- and self.live_status == status.ON
- ):
+ if old_info.live_status in [status.OFF, status.CYCLE] and self.live_status == status.ON:
return action.TURN_ON
elif old_info.live_status == status.ON and self.live_status in [
status.OFF,
@@ -281,12 +259,8 @@ class Bilibililive(StatusChange):
return action.OFF
@classmethod
- async def get_target_name(
- cls, client: AsyncClient, target: Target
- ) -> Optional[str]:
- res = await client.get(
- "https://api.bilibili.com/x/web-interface/card", params={"mid": target}
- )
+ async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
+ res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target})
res_data = json.loads(res.text)
if res_data["code"]:
return None
@@ -382,9 +356,7 @@ class BilibiliBangumi(StatusChange):
_url = "https://api.bilibili.com/pgc/review/user"
@classmethod
- async def get_target_name(
- cls, client: AsyncClient, target: Target
- ) -> Optional[str]:
+ async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
res = await client.get(cls._url, params={"media_id": target})
res_data = res.json()
if res_data["code"]:
diff --git a/nonebot_bison/platform/ff14.py b/nonebot_bison/platform/ff14.py
index 61ebc24..c7af6d4 100644
--- a/nonebot_bison/platform/ff14.py
+++ b/nonebot_bison/platform/ff14.py
@@ -1,15 +1,14 @@
-from typing import Any, Optional
+from typing import Any
from httpx import AsyncClient
from ..post import Post
-from ..types import RawPost, Target
from ..utils import scheduler
from .platform import NewMessage
+from ..types import Target, RawPost
class FF14(NewMessage):
-
categories = {}
platform_name = "ff14"
name = "最终幻想XIV官方公告"
@@ -21,9 +20,7 @@ class FF14(NewMessage):
has_target = False
@classmethod
- async def get_target_name(
- cls, client: AsyncClient, target: Target
- ) -> Optional[str]:
+ async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
return "最终幻想XIV官方公告"
async def get_sub_list(self, _) -> list[RawPost]:
diff --git a/nonebot_bison/platform/mcbbsnews.py b/nonebot_bison/platform/mcbbsnews.py
index f8020ab..1784698 100644
--- a/nonebot_bison/platform/mcbbsnews.py
+++ b/nonebot_bison/platform/mcbbsnews.py
@@ -2,15 +2,15 @@ import re
import time
import traceback
-from bs4 import BeautifulSoup, Tag
from httpx import AsyncClient
from nonebot.log import logger
+from bs4 import Tag, BeautifulSoup
from nonebot.plugin import require
from ..post import Post
-from ..types import Category, RawPost, Target
+from ..types import Target, RawPost, Category
from ..utils import SchedulerConfig, http_client
-from .platform import CategoryNotRecognize, CategoryNotSupport, NewMessage
+from .platform import NewMessage, CategoryNotSupport, CategoryNotRecognize
class McbbsnewsSchedConf(SchedulerConfig):
@@ -134,9 +134,9 @@ class McbbsNews(NewMessage):
if categoty_name in category_values:
category_id = category_keys[category_values.index(categoty_name)]
elif categoty_name in known_category_values:
- raise CategoryNotSupport("McbbsNews订阅暂不支持 {}".format(categoty_name))
+ raise CategoryNotSupport(f"McbbsNews订阅暂不支持 {categoty_name}")
else:
- raise CategoryNotRecognize("Mcbbsnews订阅尚未识别 {}".format(categoty_name))
+ raise CategoryNotRecognize(f"Mcbbsnews订阅尚未识别 {categoty_name}")
return category_id
async def parse(self, post: RawPost) -> Post:
@@ -170,7 +170,7 @@ class McbbsNews(NewMessage):
一般而言每条新闻的长度都很可观,图片生成时间比较喜人
"""
require("nonebot_plugin_htmlrender")
- from nonebot_plugin_htmlrender import capture_element, text_to_pic
+ from nonebot_plugin_htmlrender import text_to_pic, capture_element
try:
assert url
@@ -181,7 +181,7 @@ class McbbsNews(NewMessage):
device_scale_factor=3,
)
assert pic_data
- except:
+ except Exception:
err_info = traceback.format_exc()
logger.warning(f"渲染错误:{err_info}")
diff --git a/nonebot_bison/platform/ncm.py b/nonebot_bison/platform/ncm.py
index 4688f22..34883f7 100644
--- a/nonebot_bison/platform/ncm.py
+++ b/nonebot_bison/platform/ncm.py
@@ -1,23 +1,21 @@
import re
-from typing import Any, Optional
+from typing import Any
from httpx import AsyncClient
from ..post import Post
-from ..types import ApiError, RawPost, Target
-from ..utils import SchedulerConfig
from .platform import NewMessage
+from ..utils import SchedulerConfig
+from ..types import Target, RawPost, ApiError
class NcmSchedConf(SchedulerConfig):
-
name = "music.163.com"
schedule_type = "interval"
schedule_setting = {"minutes": 1}
class NcmArtist(NewMessage):
-
categories = {}
platform_name = "ncm-artist"
enable_tag = False
@@ -29,11 +27,9 @@ class NcmArtist(NewMessage):
parse_target_promot = "请输入歌手主页(包含数字ID)的链接"
@classmethod
- async def get_target_name(
- cls, client: AsyncClient, target: Target
- ) -> Optional[str]:
+ async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
res = await client.get(
- "https://music.163.com/api/artist/albums/{}".format(target),
+ f"https://music.163.com/api/artist/albums/{target}",
headers={"Referer": "https://music.163.com/"},
)
res_data = res.json()
@@ -45,16 +41,14 @@ class NcmArtist(NewMessage):
async def parse_target(cls, target_text: str) -> Target:
if re.match(r"^\d+$", target_text):
return Target(target_text)
- elif match := re.match(
- r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text
- ):
+ elif match := re.match(r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text):
return Target(match.group(1))
else:
raise cls.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]:
res = await self.client.get(
- "https://music.163.com/api/artist/albums/{}".format(target),
+ f"https://music.163.com/api/artist/albums/{target}",
headers={"Referer": "https://music.163.com/"},
)
res_data = res.json()
@@ -74,13 +68,10 @@ class NcmArtist(NewMessage):
target_name = raw_post["artist"]["name"]
pics = [raw_post["picUrl"]]
url = "https://music.163.com/#/album?id={}".format(raw_post["id"])
- return Post(
- "ncm-artist", text=text, url=url, pics=pics, target_name=target_name
- )
+ return Post("ncm-artist", text=text, url=url, pics=pics, target_name=target_name)
class NcmRadio(NewMessage):
-
categories = {}
platform_name = "ncm-radio"
enable_tag = False
@@ -92,9 +83,7 @@ class NcmRadio(NewMessage):
parse_target_promot = "请输入主播电台主页(包含数字ID)的链接"
@classmethod
- async def get_target_name(
- cls, client: AsyncClient, target: Target
- ) -> Optional[str]:
+ async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
res = await client.post(
"http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"},
@@ -109,9 +98,7 @@ class NcmRadio(NewMessage):
async def parse_target(cls, target_text: str) -> Target:
if re.match(r"^\d+$", target_text):
return Target(target_text)
- elif match := re.match(
- r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text
- ):
+ elif match := re.match(r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text):
return Target(match.group(1))
else:
raise cls.ParseTargetException()
diff --git a/nonebot_bison/platform/platform.py b/nonebot_bison/platform/platform.py
index 283bf44..9ec9073 100644
--- a/nonebot_bison/platform/platform.py
+++ b/nonebot_bison/platform/platform.py
@@ -1,29 +1,32 @@
-import json
import ssl
+import json
import time
import typing
+from typing import Any
+from dataclasses import dataclass
from abc import ABC, abstractmethod
from collections import defaultdict
-from dataclasses import dataclass
-from typing import Any, Collection, Optional, Type
+from collections.abc import Collection
import httpx
from httpx import AsyncClient
from nonebot.log import logger
from nonebot_plugin_saa import PlatformTarget
-from ..plugin_config import plugin_config
from ..post import Post
-from ..types import Category, RawPost, Tag, Target, UserSubInfo
+from ..plugin_config import plugin_config
from ..utils import ProcessContext, SchedulerConfig
+from ..types import Tag, Target, RawPost, Category, UserSubInfo
class CategoryNotSupport(Exception):
- "raise in get_category, when you know the category of the post but don't want to support it or don't support its parsing yet"
+ """raise in get_category, when you know the category of the post
+ but don't want to support it or don't support its parsing yet
+ """
class CategoryNotRecognize(Exception):
- "raise in get_category, when you don't know the category of post"
+ """raise in get_category, when you don't know the category of post"""
class RegistryMeta(type):
@@ -42,7 +45,6 @@ class RegistryMeta(type):
class PlatformMeta(RegistryMeta):
-
categories: dict[Category, str]
store: dict[Target, Any]
@@ -60,8 +62,7 @@ class PlatformABCMeta(PlatformMeta, ABC):
class Platform(metaclass=PlatformABCMeta, base=True):
-
- scheduler: Type[SchedulerConfig]
+ scheduler: type[SchedulerConfig]
ctx: ProcessContext
is_common: bool
enabled: bool
@@ -70,16 +71,14 @@ class Platform(metaclass=PlatformABCMeta, base=True):
categories: dict[Category, str]
enable_tag: bool
platform_name: str
- parse_target_promot: Optional[str] = None
- registry: list[Type["Platform"]]
+ parse_target_promot: str | None = None
+ registry: list[type["Platform"]]
client: AsyncClient
reverse_category: dict[str, Category]
@classmethod
@abstractmethod
- async def get_target_name(
- cls, client: AsyncClient, target: Target
- ) -> Optional[str]:
+ async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
...
@abstractmethod
@@ -95,11 +94,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
return await self.fetch_new_post(target, users)
except httpx.RequestError as err:
if plugin_config.bison_show_network_warning:
- logger.warning(
- "network connection error: {}, url: {}".format(
- type(err), err.request.url
- )
- )
+ logger.warning(f"network connection error: {type(err)}, url: {err.request.url}")
return []
except ssl.SSLError as err:
if plugin_config.bison_show_network_warning:
@@ -130,7 +125,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
return Target(target_string)
@abstractmethod
- def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]:
+ def get_tags(self, raw_post: RawPost) -> Collection[Tag] | None:
"Return Tag list of given RawPost"
@classmethod
@@ -201,9 +196,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
) -> list[tuple[PlatformTarget, list[Post]]]:
res: list[tuple[PlatformTarget, list[Post]]] = []
for user, cats, required_tags in users:
- user_raw_post = await self.filter_user_custom(
- new_posts, cats, required_tags
- )
+ user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
user_post: list[Post] = []
for raw_post in user_raw_post:
user_post.append(await self.do_parse(raw_post))
@@ -211,7 +204,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
return res
@abstractmethod
- def get_category(self, post: RawPost) -> Optional[Category]:
+ def get_category(self, post: RawPost) -> Category | None:
"Return category of given Rawpost"
raise NotImplementedError()
@@ -221,7 +214,7 @@ class MessageProcess(Platform, abstract=True):
def __init__(self, ctx: ProcessContext, client: AsyncClient):
super().__init__(ctx, client)
- self.parse_cache: dict[Any, Post] = dict()
+ self.parse_cache: dict[Any, Post] = {}
@abstractmethod
def get_id(self, post: RawPost) -> Any:
@@ -246,7 +239,7 @@ class MessageProcess(Platform, abstract=True):
"Get post list of the given target"
@abstractmethod
- def get_date(self, post: RawPost) -> Optional[int]:
+ def get_date(self, post: RawPost) -> int | None:
"Get post timestamp and return, return None if can't get the time"
async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
@@ -286,9 +279,7 @@ class NewMessage(MessageProcess, abstract=True):
inited: bool
exists_posts: set[Any]
- async def filter_common_with_diff(
- self, target: Target, raw_post_list: list[RawPost]
- ) -> list[RawPost]:
+ async def filter_common_with_diff(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]:
filtered_post = await self.filter_common(raw_post_list)
store = self.get_stored_data(target) or self.MessageStorage(False, set())
res = []
@@ -297,11 +288,7 @@ class NewMessage(MessageProcess, abstract=True):
for raw_post in filtered_post:
post_id = self.get_id(raw_post)
store.exists_posts.add(post_id)
- logger.info(
- "init {}-{} with {}".format(
- self.platform_name, target, store.exists_posts
- )
- )
+ logger.info(f"init {self.platform_name}-{target} with {store.exists_posts}")
store.inited = True
else:
for raw_post in filtered_post:
@@ -400,12 +387,11 @@ class SimplePost(MessageProcess, abstract=True):
return res
-def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
-
+def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
if typing.TYPE_CHECKING:
class NoTargetGroup(Platform, abstract=True):
- platform_list: list[Type[Platform]]
+ platform_list: list[type[Platform]]
platform_obj_list: list[Platform]
DUMMY_STR = "_DUMMY"
@@ -418,24 +404,18 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
for platform in platform_list:
if platform.has_target:
- raise RuntimeError(
- "Platform {} should have no target".format(platform.name)
- )
+ raise RuntimeError(f"Platform {platform.name} should have no target")
if name == DUMMY_STR:
name = platform.name
elif name != platform.name:
- raise RuntimeError("Platform name for {} not fit".format(platform_name))
+ raise RuntimeError(f"Platform name for {platform_name} not fit")
platform_category_key_set = set(platform.categories.keys())
if platform_category_key_set & categories_keys:
- raise RuntimeError(
- "Platform categories for {} duplicate".format(platform_name)
- )
+ raise RuntimeError(f"Platform categories for {platform_name} duplicate")
categories_keys |= platform_category_key_set
categories.update(platform.categories)
if platform.scheduler != scheduler:
- raise RuntimeError(
- "Platform scheduler for {} not fit".format(platform_name)
- )
+ raise RuntimeError(f"Platform scheduler for {platform_name} not fit")
def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient):
Platform.__init__(self, ctx, client)
@@ -444,15 +424,13 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
self.platform_obj_list.append(platform_class(ctx, client))
def __str__(self: "NoTargetGroup") -> str:
- return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]"
+ return "[" + " ".join(x.name for x in self.platform_list) + "]"
@classmethod
async def get_target_name(cls, client: AsyncClient, target: Target):
return await platform_list[0].get_target_name(client, target)
- async def fetch_new_post(
- self: "NoTargetGroup", target: Target, users: list[UserSubInfo]
- ):
+ async def fetch_new_post(self: "NoTargetGroup", target: Target, users: list[UserSubInfo]):
res = defaultdict(list)
for platform in self.platform_obj_list:
platform_res = await platform.fetch_new_post(target=target, users=users)
diff --git a/nonebot_bison/platform/rss.py b/nonebot_bison/platform/rss.py
index cbcddd4..94c584c 100644
--- a/nonebot_bison/platform/rss.py
+++ b/nonebot_bison/platform/rss.py
@@ -1,10 +1,10 @@
import calendar
import time
-from typing import Any, Optional
+from typing import Any
import feedparser
-from bs4 import BeautifulSoup as bs
from httpx import AsyncClient
+from bs4 import BeautifulSoup as bs
from ..post import Post
from ..types import RawPost, Target
@@ -20,7 +20,6 @@ class RssSchedConf(SchedulerConfig):
class Rss(NewMessage):
-
categories = {}
enable_tag = False
platform_name = "rss"
@@ -31,9 +30,7 @@ class Rss(NewMessage):
has_target = True
@classmethod
- async def get_target_name(
- cls, client: AsyncClient, target: Target
- ) -> Optional[str]:
+ async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
res = await client.get(target, timeout=10.0)
feed = feedparser.parse(res.text)
return feed["feed"]["title"]
@@ -69,7 +66,7 @@ class Rss(NewMessage):
else:
text = f"{title}\n\n{desc}"
- pics = list(map(lambda x: x.attrs["src"], soup("img")))
+ pics = [x.attrs["src"] for x in soup("img")]
if raw_post.get("media_content"):
for media in raw_post["media_content"]:
if media.get("medium") == "image" and media.get("url"):
diff --git a/nonebot_bison/platform/weibo.py b/nonebot_bison/platform/weibo.py
index 2fccea2..e8d4197 100644
--- a/nonebot_bison/platform/weibo.py
+++ b/nonebot_bison/platform/weibo.py
@@ -1,17 +1,16 @@
-import json
import re
-from collections.abc import Callable
+import json
+from typing import Any
from datetime import datetime
-from typing import Any, Optional
-from bs4 import BeautifulSoup as bs
from httpx import AsyncClient
from nonebot.log import logger
+from bs4 import BeautifulSoup as bs
from ..post import Post
-from ..types import *
-from ..utils import SchedulerConfig, http_client
from .platform import NewMessage
+from ..utils import SchedulerConfig, http_client
+from ..types import Tag, Target, RawPost, ApiError, Category
class WeiboSchedConf(SchedulerConfig):
@@ -21,7 +20,6 @@ class WeiboSchedConf(SchedulerConfig):
class Weibo(NewMessage):
-
categories = {
1: "转发",
2: "视频",
@@ -38,13 +36,9 @@ class Weibo(NewMessage):
parse_target_promot = "请输入用户主页(包含数字UID)的链接"
@classmethod
- async def get_target_name(
- cls, client: AsyncClient, target: Target
- ) -> Optional[str]:
+ async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
param = {"containerid": "100505" + target}
- res = await client.get(
- "https://m.weibo.cn/api/container/getIndex", params=param
- )
+ res = await client.get("https://m.weibo.cn/api/container/getIndex", params=param)
res_dict = json.loads(res.text)
if res_dict.get("ok") == 1:
return res_dict["data"]["userInfo"]["screen_name"]
@@ -63,13 +57,14 @@ class Weibo(NewMessage):
async def get_sub_list(self, target: Target) -> list[RawPost]:
params = {"containerid": "107603" + target}
- res = await self.client.get(
- "https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0
- )
+ res = await self.client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0)
res_data = json.loads(res.text)
if not res_data["ok"] and res_data["msg"] != "这里还没有内容":
raise ApiError(res.request.url)
- custom_filter: Callable[[RawPost], bool] = lambda d: d["card_type"] == 9
+
+ def custom_filter(d: RawPost) -> bool:
+ return d["card_type"] == 9
+
return list(filter(custom_filter, res_data["data"]["cards"]))
def get_id(self, post: RawPost) -> Any:
@@ -79,44 +74,32 @@ class Weibo(NewMessage):
return raw_post["card_type"] == 9
def get_date(self, raw_post: RawPost) -> float:
- created_time = datetime.strptime(
- raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y"
- )
+ created_time = datetime.strptime(raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y")
return created_time.timestamp()
- def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]:
+ def get_tags(self, raw_post: RawPost) -> list[Tag] | None:
"Return Tag list of given RawPost"
text = raw_post["mblog"]["text"]
soup = bs(text, "html.parser")
- res = list(
- map(
- lambda x: x[1:-1],
- filter(
- lambda s: s[0] == "#" and s[-1] == "#",
- map(lambda x: x.text, soup.find_all("span", class_="surl-text")),
- ),
+ res = [
+ x[1:-1]
+ for x in filter(
+ lambda s: s[0] == "#" and s[-1] == "#",
+ (x.text for x in soup.find_all("span", class_="surl-text")),
)
- )
- super_topic_img = soup.find(
- "img", src=re.compile(r"timeline_card_small_super_default")
- )
+ ]
+ super_topic_img = soup.find("img", src=re.compile(r"timeline_card_small_super_default"))
if super_topic_img:
try:
- res.append(
- super_topic_img.parent.parent.find("span", class_="surl-text").text # type: ignore
- + "超话"
- )
- except:
- logger.info("super_topic extract error: {}".format(text))
+ res.append(super_topic_img.parent.parent.find("span", class_="surl-text").text + "超话") # type: ignore
+ except Exception:
+ logger.info(f"super_topic extract error: {text}")
return res
def get_category(self, raw_post: RawPost) -> Category:
if raw_post["mblog"].get("retweeted_status"):
return Category(1)
- elif (
- raw_post["mblog"].get("page_info")
- and raw_post["mblog"]["page_info"].get("type") == "video"
- ):
+ elif raw_post["mblog"].get("page_info") and raw_post["mblog"]["page_info"].get("type") == "video":
return Category(2)
elif raw_post["mblog"].get("pics"):
return Category(3)
@@ -129,7 +112,8 @@ class Weibo(NewMessage):
async def parse(self, raw_post: RawPost) -> Post:
header = {
- "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
+ "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,"
+ "*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
"accept-language": "zh-CN,zh;q=0.9",
"authority": "m.weibo.cn",
"cache-control": "max-age=0",
@@ -147,26 +131,16 @@ class Weibo(NewMessage):
retweeted = True
pic_num = info["retweeted_status"]["pic_num"] if retweeted else info["pic_num"]
if info["isLongText"] or pic_num > 9:
- res = await self.client.get(
- "https://m.weibo.cn/detail/{}".format(info["mid"]), headers=header
- )
+ res = await self.client.get(f"https://m.weibo.cn/detail/{info['mid']}", headers=header)
try:
match = re.search(r'"status": ([\s\S]+),\s+"call"', res.text)
assert match
full_json_text = match.group(1)
info = json.loads(full_json_text)
- except:
- logger.info(
- "detail message error: https://m.weibo.cn/detail/{}".format(
- info["mid"]
- )
- )
+ except Exception:
+ logger.info(f"detail message error: https://m.weibo.cn/detail/{info['mid']}")
parsed_text = self._get_text(info["text"])
- raw_pics_list = (
- info["retweeted_status"].get("pics", [])
- if retweeted
- else info.get("pics", [])
- )
+ raw_pics_list = info["retweeted_status"].get("pics", []) if retweeted else info.get("pics", [])
pic_urls = [img["large"]["url"] for img in raw_pics_list]
pics = []
for pic_url in pic_urls:
@@ -174,7 +148,7 @@ class Weibo(NewMessage):
res = await client.get(pic_url)
res.raise_for_status()
pics.append(res.content)
- detail_url = "https://weibo.com/{}/{}".format(info["user"]["id"], info["bid"])
+ detail_url = f"https://weibo.com/{info['user']['id']}/{info['bid']}"
# return parsed_text, detail_url, pic_urls
return Post(
"weibo",
diff --git a/nonebot_bison/plugin_config.py b/nonebot_bison/plugin_config.py
index cb35ed0..041b38a 100644
--- a/nonebot_bison/plugin_config.py
+++ b/nonebot_bison/plugin_config.py
@@ -1,11 +1,8 @@
-from typing import Optional
-
import nonebot
from pydantic import BaseSettings
class PlugConfig(BaseSettings):
-
bison_config_path: str = ""
bison_use_pic: bool = False
bison_init_filter: bool = True
@@ -17,8 +14,12 @@ class PlugConfig(BaseSettings):
bison_use_pic_merge: int = 0 # 多图片时启用图片合并转发(仅限群)
# 0:不启用;1:首条消息单独发送,剩余照片合并转发;2以及以上:所有消息全部合并转发
bison_resend_times: int = 0
- bison_proxy: Optional[str]
- bison_ua: str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36"
+ bison_proxy: str | None
+ bison_ua: str = (
+ "Mozilla/5.0 (X11; Linux x86_64) "
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
+ "Chrome/51.0.2704.103 Safari/537.36"
+ )
bison_show_network_warning: bool = True
class Config:
diff --git a/nonebot_bison/post/__init__.py b/nonebot_bison/post/__init__.py
index ff93bec..3900f47 100644
--- a/nonebot_bison/post/__init__.py
+++ b/nonebot_bison/post/__init__.py
@@ -1,3 +1,3 @@
from .post import Post
-__all__ = ["Post", "CustomPost"]
+__all__ = ["Post"]
diff --git a/nonebot_bison/post/abstract_post.py b/nonebot_bison/post/abstract_post.py
index ee055d5..a8d88de 100644
--- a/nonebot_bison/post/abstract_post.py
+++ b/nonebot_bison/post/abstract_post.py
@@ -1,7 +1,6 @@
-from abc import abstractmethod
-from dataclasses import dataclass, field
from functools import reduce
-from typing import Optional
+from abc import abstractmethod
+from dataclasses import field, dataclass
from nonebot_plugin_saa import MessageFactory, MessageSegmentFactory
@@ -25,12 +24,12 @@ class BasePost:
class OptionalMixin:
# Because of https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
- override_use_pic: Optional[bool] = None
+ override_use_pic: bool | None = None
compress: bool = False
extra_msg: list[MessageFactory] = field(default_factory=list)
def _use_pic(self):
- if not self.override_use_pic is None:
+ if self.override_use_pic is not None:
return self.override_use_pic
return plugin_config.bison_use_pic
@@ -44,13 +43,9 @@ class AbstractPost(OptionalMixin, BasePost):
msg_segments = await self.generate_text_messages()
if msg_segments:
if self.compress:
- msgs = [
- reduce(lambda x, y: x.append(y), msg_segments, MessageFactory([]))
- ]
+ msgs = [reduce(lambda x, y: x.append(y), msg_segments, MessageFactory([]))]
else:
- msgs = list(
- map(lambda msg_segment: MessageFactory([msg_segment]), msg_segments)
- )
+ msgs = [MessageFactory([msg_segment]) for msg_segment in msg_segments]
else:
msgs = []
msgs.extend(self.extra_msg)
diff --git a/nonebot_bison/post/custom_post.py b/nonebot_bison/post/custom_post.py
index 951a9c8..4921bc2 100644
--- a/nonebot_bison/post/custom_post.py
+++ b/nonebot_bison/post/custom_post.py
@@ -1,19 +1,17 @@
-from dataclasses import dataclass, field
-from typing import Optional
+from dataclasses import field, dataclass
-from nonebot.adapters.onebot.v11 import MessageSegment
from nonebot.log import logger
from nonebot.plugin import require
-from nonebot_plugin_saa import Image, MessageFactory, MessageSegmentFactory, Text
+from nonebot.adapters.onebot.v11 import MessageSegment
+from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
-from .abstract_post import AbstractPost, BasePost
+from .abstract_post import BasePost, AbstractPost
@dataclass
class _CustomPost(BasePost):
-
ms_factories: list[MessageSegmentFactory] = field(default_factory=list)
- css_path: Optional[str] = None # 模板文件所用css路径
+ css_path: str | None = None # 模板文件所用css路径
async def generate_text_messages(self) -> list[MessageSegmentFactory]:
return self.ms_factories
@@ -31,15 +29,13 @@ class _CustomPost(BasePost):
for message_segment in self.ms_factories:
match message_segment:
case Text(data={"text": text}):
- md += "{}
".format(text)
+ md += f"{text}
"
case Image(data={"image": image}):
# use onebot v11 to convert image into url
ob11_image = MessageSegment.image(image)
md += "\n".format(ob11_image.data["file"])
case _:
- logger.warning(
- "custom_post不支持处理类型:{}".format(type(message_segment))
- )
+ logger.warning(f"custom_post不支持处理类型:{type(message_segment)}")
continue
return md
diff --git a/nonebot_bison/post/post.py b/nonebot_bison/post/post.py
index afec2ce..c64096e 100644
--- a/nonebot_bison/post/post.py
+++ b/nonebot_bison/post/post.py
@@ -1,30 +1,27 @@
-from dataclasses import dataclass, field
-from functools import reduce
from io import BytesIO
-from typing import Optional, Union
+from dataclasses import field, dataclass
-import nonebot_plugin_saa as saa
-from nonebot.log import logger
-from nonebot_plugin_saa.utils import MessageFactory, MessageSegmentFactory
from PIL import Image
+from nonebot.log import logger
+import nonebot_plugin_saa as saa
+from nonebot_plugin_saa.utils import MessageSegmentFactory
-from ..utils import http_client, parse_text
-from .abstract_post import AbstractPost, BasePost, OptionalMixin
+from ..utils import parse_text, http_client
+from .abstract_post import BasePost, AbstractPost
@dataclass
class _Post(BasePost):
-
target_type: str
text: str
- url: Optional[str] = None
- target_name: Optional[str] = None
- pics: list[Union[str, bytes]] = field(default_factory=list)
+ url: str | None = None
+ target_name: str | None = None
+ pics: list[str | bytes] = field(default_factory=list)
- _message: Optional[list[MessageSegmentFactory]] = None
- _pic_message: Optional[list[MessageSegmentFactory]] = None
+ _message: list[MessageSegmentFactory] | None = None
+ _pic_message: list[MessageSegmentFactory] | None = None
- async def _pic_url_to_image(self, data: Union[str, bytes]) -> Image.Image:
+ async def _pic_url_to_image(self, data: str | bytes) -> Image.Image:
pic_buffer = BytesIO()
if isinstance(data, str):
async with http_client() as client:
@@ -101,22 +98,19 @@ class _Post(BasePost):
self.pics.insert(0, target_io.getvalue())
async def generate_text_messages(self) -> list[MessageSegmentFactory]:
-
if self._message is None:
await self._pic_merge()
msg_segments: list[MessageSegmentFactory] = []
text = ""
if self.text:
- text += "{}".format(
- self.text if len(self.text) < 500 else self.text[:500] + "..."
- )
+ text += "{}".format(self.text if len(self.text) < 500 else self.text[:500] + "...")
if text:
text += "\n"
- text += "来源: {}".format(self.target_type)
+ text += f"来源: {self.target_type}"
if self.target_name:
- text += " {}".format(self.target_name)
+ text += f" {self.target_name}"
if self.url:
- text += " \n详情: {}".format(self.url)
+ text += f" \n详情: {self.url}"
msg_segments.append(saa.Text(text))
for pic in self.pics:
msg_segments.append(saa.Image(pic))
@@ -124,17 +118,16 @@ class _Post(BasePost):
return self._message
async def generate_pic_messages(self) -> list[MessageSegmentFactory]:
-
if self._pic_message is None:
await self._pic_merge()
msg_segments: list[MessageSegmentFactory] = []
text = ""
if self.text:
- text += "{}".format(self.text)
+ text += f"{self.text}"
text += "\n"
- text += "来源: {}".format(self.target_type)
+ text += f"来源: {self.target_type}"
if self.target_name:
- text += " {}".format(self.target_name)
+ text += f" {self.target_name}"
msg_segments.append(await parse_text(text))
if not self.target_type == "rss" and self.url:
msg_segments.append(saa.Text(self.url))
@@ -149,14 +142,7 @@ class _Post(BasePost):
self.target_name,
self.text if len(self.text) < 500 else self.text[:500] + "...",
self.url,
- ", ".join(
- map(
- lambda x: "b64img"
- if isinstance(x, bytes) or x.startswith("base64")
- else x,
- self.pics,
- )
- ),
+ ", ".join("b64img" if isinstance(x, bytes) or x.startswith("base64") else x for x in self.pics),
)
diff --git a/nonebot_bison/scheduler/__init__.py b/nonebot_bison/scheduler/__init__.py
index 4fe6284..19c9284 100644
--- a/nonebot_bison/scheduler/__init__.py
+++ b/nonebot_bison/scheduler/__init__.py
@@ -1 +1,3 @@
-from .manager import *
+from .manager import init_scheduler, scheduler_dict, handle_delete_target, handle_insert_new_target
+
+__all__ = ["init_scheduler", "handle_delete_target", "handle_insert_new_target", "scheduler_dict"]
diff --git a/nonebot_bison/scheduler/manager.py b/nonebot_bison/scheduler/manager.py
index 9c48676..271f1b5 100644
--- a/nonebot_bison/scheduler/manager.py
+++ b/nonebot_bison/scheduler/manager.py
@@ -1,18 +1,16 @@
-from typing import Type
-
from ..config import config
-from ..config.db_model import Target
-from ..platform import platform_manager
-from ..types import Target as T_Target
-from ..utils import SchedulerConfig
from .scheduler import Scheduler
+from ..utils import SchedulerConfig
+from ..config.db_model import Target
+from ..types import Target as T_Target
+from ..platform import platform_manager
-scheduler_dict: dict[Type[SchedulerConfig], Scheduler] = {}
+scheduler_dict: dict[type[SchedulerConfig], Scheduler] = {}
async def init_scheduler():
- _schedule_class_dict: dict[Type[SchedulerConfig], list[Target]] = {}
- _schedule_class_platform_dict: dict[Type[SchedulerConfig], list[str]] = {}
+ _schedule_class_dict: dict[type[SchedulerConfig], list[Target]] = {}
+ _schedule_class_platform_dict: dict[type[SchedulerConfig], list[str]] = {}
for platform in platform_manager.values():
scheduler_config = platform.scheduler
if not hasattr(scheduler_config, "name") or not scheduler_config.name:
@@ -33,9 +31,7 @@ async def init_scheduler():
for target in target_list:
schedulable_args.append((target.platform_name, T_Target(target.target)))
platform_name_list = _schedule_class_platform_dict[scheduler_config]
- scheduler_dict[scheduler_config] = Scheduler(
- scheduler_config, schedulable_args, platform_name_list
- )
+ scheduler_dict[scheduler_config] = Scheduler(scheduler_config, schedulable_args, platform_name_list)
config.register_add_target_hook(handle_insert_new_target)
config.register_delete_target_hook(handle_delete_target)
diff --git a/nonebot_bison/scheduler/scheduler.py b/nonebot_bison/scheduler/scheduler.py
index 4f26b69..2d1d606 100644
--- a/nonebot_bison/scheduler/scheduler.py
+++ b/nonebot_bison/scheduler/scheduler.py
@@ -1,14 +1,13 @@
from dataclasses import dataclass
-from typing import Optional, Type
from nonebot.log import logger
from nonebot_plugin_apscheduler import scheduler
from nonebot_plugin_saa.utils.exceptions import NoBotFound
-from ..config import config
-from ..platform import platform_manager
-from ..send import send_msgs
from ..types import Target
+from ..config import config
+from ..send import send_msgs
+from ..platform import platform_manager
from ..utils import ProcessContext, SchedulerConfig
@@ -20,12 +19,11 @@ class Schedulable:
class Scheduler:
-
schedulable_list: list[Schedulable]
def __init__(
self,
- scheduler_config: Type[SchedulerConfig],
+ scheduler_config: type[SchedulerConfig],
schedulables: list[tuple[str, Target]],
platform_name_list: list[str],
):
@@ -37,15 +35,12 @@ class Scheduler:
self.scheduler_config_obj = self.scheduler_config()
self.schedulable_list = []
for platform_name, target in schedulables:
- self.schedulable_list.append(
- Schedulable(
- platform_name=platform_name, target=target, current_weight=0
- )
- )
+ self.schedulable_list.append(Schedulable(platform_name=platform_name, target=target, current_weight=0))
self.platform_name_list = platform_name_list
self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值
logger.info(
- f"register scheduler for {self.name} with {self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}"
+ f"register scheduler for {self.name} with "
+ f"{self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}"
)
scheduler.add_job(
self.exec_fetch,
@@ -53,7 +48,7 @@ class Scheduler:
**self.scheduler_config.schedule_setting,
)
- async def get_next_schedulable(self) -> Optional[Schedulable]:
+ async def get_next_schedulable(self) -> Schedulable | None:
if not self.schedulable_list:
return None
cur_weight = await config.get_current_weight_val(self.platform_name_list)
@@ -61,16 +56,9 @@ class Scheduler:
self.pre_weight_val = 0
cur_max_schedulable = None
for schedulable in self.schedulable_list:
- schedulable.current_weight += cur_weight[
- f"{schedulable.platform_name}-{schedulable.target}"
- ]
- weight_sum += cur_weight[
- f"{schedulable.platform_name}-{schedulable.target}"
- ]
- if (
- not cur_max_schedulable
- or cur_max_schedulable.current_weight < schedulable.current_weight
- ):
+ schedulable.current_weight += cur_weight[f"{schedulable.platform_name}-{schedulable.target}"]
+ weight_sum += cur_weight[f"{schedulable.platform_name}-{schedulable.target}"]
+ if not cur_max_schedulable or cur_max_schedulable.current_weight < schedulable.current_weight:
cur_max_schedulable = schedulable
assert cur_max_schedulable
cur_max_schedulable.current_weight -= weight_sum
@@ -80,9 +68,7 @@ class Scheduler:
context = ProcessContext()
if not (schedulable := await self.get_next_schedulable()):
return
- logger.trace(
- f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}"
- )
+ logger.trace(f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}")
send_userinfo_list = await config.get_platform_target_subscribers(
schedulable.platform_name, schedulable.target
)
@@ -92,9 +78,7 @@ class Scheduler:
try:
platform_obj = platform_manager[schedulable.platform_name](context, client)
- to_send = await platform_obj.do_fetch_new_post(
- schedulable.target, send_userinfo_list
- )
+ to_send = await platform_obj.do_fetch_new_post(schedulable.target, send_userinfo_list)
except Exception as err:
records = context.gen_req_records()
for record in records:
@@ -107,7 +91,7 @@ class Scheduler:
for user, send_list in to_send:
for send_post in send_list:
- logger.info("send to {}: {}".format(user, send_post))
+ logger.info(f"send to {user}: {send_post}")
try:
await send_msgs(
user,
@@ -119,19 +103,14 @@ class Scheduler:
def insert_new_schedulable(self, platform_name: str, target: Target):
self.pre_weight_val += 1000
self.schedulable_list.append(Schedulable(platform_name, target, 1000))
- logger.info(
- f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})"
- )
+ logger.info(f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})")
def delete_schedulable(self, platform_name, target: Target):
if not self.schedulable_list:
return
to_find_idx = None
for idx, schedulable in enumerate(self.schedulable_list):
- if (
- schedulable.platform_name == platform_name
- and schedulable.target == target
- ):
+ if schedulable.platform_name == platform_name and schedulable.target == target:
to_find_idx = idx
break
if to_find_idx is not None:
diff --git a/nonebot_bison/script/cli.py b/nonebot_bison/script/cli.py
index 48883c8..1f416c5 100644
--- a/nonebot_bison/script/cli.py
+++ b/nonebot_bison/script/cli.py
@@ -1,21 +1,23 @@
-import importlib
import json
import time
-from functools import partial, wraps
+import importlib
from pathlib import Path
from types import ModuleType
-from typing import Any, Callable, Coroutine, TypeVar
+from typing import Any, TypeVar
+from functools import wraps, partial
+from collections.abc import Callable, Coroutine
from nonebot.log import logger
-from ..config.subs_io import subscribes_export, subscribes_import
-from ..config.subs_io.nbesf_model import v1, v2
from ..scheduler.manager import init_scheduler
+from ..config.subs_io.nbesf_model import v1, v2
+from ..config.subs_io import subscribes_export, subscribes_import
try:
+ from typing_extensions import ParamSpec
+
import anyio
import click
- from typing_extensions import ParamSpec
except ImportError as e: # pragma: no cover
raise ImportError("请使用 `pip install nonebot-bison[cli]` 安装所需依赖") from e
@@ -65,9 +67,7 @@ def path_init(ctx, param, value):
@cli.command(help="导出Nonebot Bison Exchangable Subcribes File", name="export")
-@click.option(
- "--path", "-p", default=None, callback=path_init, help="导出路径, 如果不指定,则默认为工作目录"
-)
+@click.option("--path", "-p", default=None, callback=path_init, help="导出路径, 如果不指定,则默认为工作目录")
@click.option(
"--format",
default="json",
@@ -76,7 +76,6 @@ def path_init(ctx, param, value):
)
@run_async
async def subs_export(path: Path, format: str):
-
await init_scheduler()
export_file = path / f"bison_subscribes_export_{int(time.time())}.{format}"
@@ -121,7 +120,6 @@ async def subs_export(path: Path, format: str):
)
@run_async
async def subs_import(path: str, format: str):
-
await init_scheduler()
import_file_path = Path(path)
diff --git a/nonebot_bison/send.py b/nonebot_bison/send.py
index 540eba5..62d1af6 100644
--- a/nonebot_bison/send.py
+++ b/nonebot_bison/send.py
@@ -1,17 +1,16 @@
import asyncio
from collections import deque
-from typing import Deque
-from nonebot.adapters.onebot.v11.exception import ActionFailed
from nonebot.log import logger
-from nonebot_plugin_saa import AggregatedMessageFactory, MessageFactory, PlatformTarget
+from nonebot.adapters.onebot.v11.exception import ActionFailed
from nonebot_plugin_saa.utils.auto_select_bot import refresh_bots
+from nonebot_plugin_saa import MessageFactory, PlatformTarget, AggregatedMessageFactory
from .plugin_config import plugin_config
Sendable = MessageFactory | AggregatedMessageFactory
-QUEUE: Deque[tuple[PlatformTarget, Sendable, int]] = deque()
+QUEUE: deque[tuple[PlatformTarget, Sendable, int]] = deque()
MESSGE_SEND_INTERVAL = 1.5
diff --git a/nonebot_bison/sub_manager/add_sub.py b/nonebot_bison/sub_manager/add_sub.py
index f1539b5..af676d8 100644
--- a/nonebot_bison/sub_manager/add_sub.py
+++ b/nonebot_bison/sub_manager/add_sub.py
@@ -1,21 +1,20 @@
import contextlib
-from typing import Type, cast
-from nonebot.adapters import Message, MessageTemplate
+from nonebot.typing import T_State
from nonebot.matcher import Matcher
from nonebot.params import Arg, ArgPlainText
-from nonebot.typing import T_State
-from nonebot_plugin_saa import PlatformTarget, SupportedAdapters, Text
+from nonebot.adapters import Message, MessageTemplate
+from nonebot_plugin_saa import Text, PlatformTarget, SupportedAdapters
-from ..apis import check_sub_target
-from ..config import config
-from ..config.db_config import SubscribeDupException
-from ..platform import Platform, platform_manager
from ..types import Target
+from ..config import config
+from ..apis import check_sub_target
+from ..platform import Platform, platform_manager
+from ..config.db_config import SubscribeDupException
from .utils import common_platform, ensure_user_info, gen_handle_cancel
-def do_add_sub(add_sub: Type[Matcher]):
+def do_add_sub(add_sub: type[Matcher]):
handle_cancel = gen_handle_cancel(add_sub, "已中止订阅")
add_sub.handle()(ensure_user_info(add_sub))
@@ -25,12 +24,7 @@ def do_add_sub(add_sub: Type[Matcher]):
state["_prompt"] = (
"请输入想要订阅的平台,目前支持,请输入冒号左边的名称:\n"
+ "".join(
- [
- "{}:{}\n".format(
- platform_name, platform_manager[platform_name].name
- )
- for platform_name in common_platform
- ]
+ [f"{platform_name}: {platform_manager[platform_name].name}\n" for platform_name in common_platform]
)
+ "要查看全部平台请输入:“全部”\n中止订阅过程请输入:“取消”"
)
@@ -39,10 +33,7 @@ def do_add_sub(add_sub: Type[Matcher]):
async def parse_platform(state: T_State, platform: str = ArgPlainText()) -> None:
if platform == "全部":
message = "全部平台\n" + "\n".join(
- [
- "{}:{}".format(platform_name, platform.name)
- for platform_name, platform in platform_manager.items()
- ]
+ [f"{platform_name}: {platform.name}" for platform_name, platform in platform_manager.items()]
)
await add_sub.reject(message)
elif platform == "取消":
@@ -57,9 +48,7 @@ def do_add_sub(add_sub: Type[Matcher]):
cur_platform = platform_manager[state["platform"]]
if cur_platform.has_target:
state["_prompt"] = (
- ("1." + cur_platform.parse_target_promot + "\n2.")
- if cur_platform.parse_target_promot
- else ""
+ ("1." + cur_platform.parse_target_promot + "\n2.") if cur_platform.parse_target_promot else ""
) + "请输入订阅用户的id\n查询id获取方法请回复:“查询”"
else:
matcher.set_arg("raw_id", None) # type: ignore
@@ -81,9 +70,7 @@ def do_add_sub(add_sub: Type[Matcher]):
image = "https://s3.bmp.ovh/imgs/2022/03/ab3cc45d83bd3dd3.jpg"
msg.overwrite(
SupportedAdapters.onebot_v11,
- MessageSegment.share(
- url=url, title=title, content=content, image=image
- ),
+ MessageSegment.share(url=url, title=title, content=content, image=image),
)
await msg.reject()
platform = platform_manager[state["platform"]]
@@ -99,14 +86,12 @@ def do_add_sub(add_sub: Type[Matcher]):
await add_sub.reject("id输入错误")
state["id"] = raw_id_text
state["name"] = name
- except (Platform.ParseTargetException):
+ except Platform.ParseTargetException:
await add_sub.reject("不能从你的输入中提取出id,请检查你输入的内容是否符合预期")
else:
await add_sub.send(
- "即将订阅的用户为:{} {} {}\n如有错误请输入“取消”重新订阅".format(
- state["platform"], state["name"], state["id"]
- )
- )
+ f"即将订阅的用户为:{state['platform']} {state['name']} {state['id']}\n如有错误请输入“取消”重新订阅"
+ ) # noqa: E501
@add_sub.handle()
async def prepare_get_categories(matcher: Matcher, state: T_State):
@@ -125,7 +110,7 @@ def do_add_sub(add_sub: Type[Matcher]):
if platform_manager[state["platform"]].categories:
for cat in raw_cats_text.split():
if cat not in platform_manager[state["platform"]].reverse_category:
- await add_sub.reject("不支持 {}".format(cat))
+ await add_sub.reject(f"不支持 {cat}")
res.append(platform_manager[state["platform"]].reverse_category[cat])
state["cats"] = res
@@ -135,14 +120,16 @@ def do_add_sub(add_sub: Type[Matcher]):
matcher.set_arg("raw_tags", None) # type: ignore
state["tags"] = []
return
- state["_prompt"] = '请输入要订阅/屏蔽的标签(不含#号)\n多个标签请使用空格隔开\n订阅所有标签输入"全部标签"\n具体规则回复"详情"'
+ state["_prompt"] = "请输入要订阅/屏蔽的标签(不含#号)\n" "多个标签请使用空格隔开\n" '订阅所有标签输入"全部标签"\n' '具体规则回复"详情"' # noqa: E501
@add_sub.got("raw_tags", MessageTemplate("{_prompt}"), [handle_cancel])
async def parser_tags(state: T_State, raw_tags: Message = Arg()):
raw_tags_text = raw_tags.extract_plain_text()
if raw_tags_text == "详情":
await add_sub.reject(
- "订阅标签直接输入标签内容\n屏蔽标签请在标签名称前添加~号\n详见https://nonebot-bison.netlify.app/usage/#%E5%B9%B3%E5%8F%B0%E8%AE%A2%E9%98%85%E6%A0%87%E7%AD%BE-tag"
+ "订阅标签直接输入标签内容\n"
+ "屏蔽标签请在标签名称前添加~号\n"
+ "详见https://nonebot-bison.netlify.app/usage/#%E5%B9%B3%E5%8F%B0%E8%AE%A2%E9%98%85%E6%A0%87%E7%AD%BE-tag"
)
if raw_tags_text in ["全部标签", "全部", "全标签"]:
state["tags"] = []
@@ -150,9 +137,7 @@ def do_add_sub(add_sub: Type[Matcher]):
state["tags"] = raw_tags_text.split()
@add_sub.handle()
- async def add_sub_process(
- state: T_State, user: PlatformTarget = Arg("target_user_info")
- ):
+ async def add_sub_process(state: T_State, user: PlatformTarget = Arg("target_user_info")):
try:
await config.add_subscribe(
user=user,
diff --git a/nonebot_bison/sub_manager/del_sub.py b/nonebot_bison/sub_manager/del_sub.py
index c0e33d8..c1003c2 100644
--- a/nonebot_bison/sub_manager/del_sub.py
+++ b/nonebot_bison/sub_manager/del_sub.py
@@ -1,26 +1,22 @@
-from typing import Type
-
+from nonebot.typing import T_State
from nonebot.matcher import Matcher
from nonebot.params import Arg, EventPlainText
-from nonebot.typing import T_State
from nonebot_plugin_saa import MessageFactory, PlatformTarget
from ..config import config
-from ..platform import platform_manager
from ..types import Category
from ..utils import parse_text
+from ..platform import platform_manager
from .utils import ensure_user_info, gen_handle_cancel
-def do_del_sub(del_sub: Type[Matcher]):
+def do_del_sub(del_sub: type[Matcher]):
handle_cancel = gen_handle_cancel(del_sub, "删除中止")
del_sub.handle()(ensure_user_info(del_sub))
@del_sub.handle()
- async def send_list(
- state: T_State, user_info: PlatformTarget = Arg("target_user_info")
- ):
+ async def send_list(state: T_State, user_info: PlatformTarget = Arg("target_user_info")):
sub_list = await config.list_subscribe(user_info)
if not sub_list:
await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅")
@@ -31,22 +27,10 @@ def do_del_sub(del_sub: Type[Matcher]):
"platform_name": sub.target.platform_name,
"target": sub.target.target,
}
- res += "{} {} {} {}\n".format(
- index,
- sub.target.platform_name,
- sub.target.target_name,
- sub.target.target,
- )
+ res += f"{index} {sub.target.platform_name} {sub.target.target_name} {sub.target.target}\n"
platform = platform_manager[sub.target.platform_name]
if platform.categories:
- res += " [{}]".format(
- ", ".join(
- map(
- lambda x: platform.categories[Category(x)],
- sub.categories,
- )
- )
- )
+ res += " [{}]".format(", ".join(platform.categories[Category(x)] for x in sub.categories))
if platform.enable_tag:
res += " {}".format(", ".join(sub.tags))
res += "\n"
@@ -62,7 +46,7 @@ def do_del_sub(del_sub: Type[Matcher]):
try:
index = int(index_str)
await config.del_subscribe(user_info, **state["sub_table"][index])
- except Exception as e:
+ except Exception:
await del_sub.reject("删除错误")
else:
await del_sub.finish("删除成功")
diff --git a/nonebot_bison/sub_manager/query_sub.py b/nonebot_bison/sub_manager/query_sub.py
index dd301f7..5bce812 100644
--- a/nonebot_bison/sub_manager/query_sub.py
+++ b/nonebot_bison/sub_manager/query_sub.py
@@ -1,17 +1,15 @@
-from typing import Type
-
-from nonebot.matcher import Matcher
from nonebot.params import Arg
+from nonebot.matcher import Matcher
from nonebot_plugin_saa import MessageFactory, PlatformTarget
from ..config import config
-from ..platform import platform_manager
from ..types import Category
from ..utils import parse_text
from .utils import ensure_user_info
+from ..platform import platform_manager
-def do_query_sub(query_sub: Type[Matcher]):
+def do_query_sub(query_sub: type[Matcher]):
query_sub.handle()(ensure_user_info(query_sub))
@query_sub.handle()
@@ -19,19 +17,10 @@ def do_query_sub(query_sub: Type[Matcher]):
sub_list = await config.list_subscribe(user_info)
res = "订阅的帐号为:\n"
for sub in sub_list:
- res += "{} {} {}".format(
- # sub["target_type"], sub["target_name"], sub["target"]
- sub.target.platform_name,
- sub.target.target_name,
- sub.target.target,
- )
+ res += f"{sub.target.platform_name} {sub.target.target_name} {sub.target.target}"
platform = platform_manager[sub.target.platform_name]
if platform.categories:
- res += " [{}]".format(
- ", ".join(
- map(lambda x: platform.categories[Category(x)], sub.categories)
- )
- )
+ res += " [{}]".format(", ".join(platform.categories[Category(x)] for x in sub.categories))
if platform.enable_tag:
res += " {}".format(", ".join(sub.tags))
res += "\n"
diff --git a/nonebot_bison/sub_manager/utils.py b/nonebot_bison/sub_manager/utils.py
index 7c41f08..d069496 100644
--- a/nonebot_bison/sub_manager/utils.py
+++ b/nonebot_bison/sub_manager/utils.py
@@ -1,13 +1,13 @@
import contextlib
-from typing import Annotated, Type
+from typing import Annotated
-from nonebot.adapters import Event
-from nonebot.matcher import Matcher
-from nonebot.params import Depends, EventPlainText, EventToMe
-from nonebot.permission import SUPERUSER
from nonebot.rule import Rule
+from nonebot.adapters import Event
from nonebot.typing import T_State
+from nonebot.matcher import Matcher
+from nonebot.permission import SUPERUSER
from nonebot_plugin_saa import extract_target
+from nonebot.params import Depends, EventToMe, EventPlainText
from ..platform import platform_manager
from ..plugin_config import plugin_config
@@ -31,7 +31,7 @@ common_platform = [
]
-def gen_handle_cancel(matcher: Type[Matcher], message: str):
+def gen_handle_cancel(matcher: type[Matcher], message: str):
async def _handle_cancel(text: Annotated[str, EventPlainText()]):
if text == "取消":
await matcher.finish(message)
@@ -39,12 +39,10 @@ def gen_handle_cancel(matcher: Type[Matcher], message: str):
return Depends(_handle_cancel)
-def ensure_user_info(matcher: Type[Matcher]):
+def ensure_user_info(matcher: type[Matcher]):
async def _check_user_info(state: T_State):
if not state.get("target_user_info"):
- await matcher.finish(
- "No target_user_info set, this shouldn't happen, please issue"
- )
+ await matcher.finish("No target_user_info set, this shouldn't happen, please issue")
return _check_user_info
diff --git a/nonebot_bison/utils/__init__.py b/nonebot_bison/utils/__init__.py
index 60aeacc..64d042d 100644
--- a/nonebot_bison/utils/__init__.py
+++ b/nonebot_bison/utils/__init__.py
@@ -1,17 +1,16 @@
import difflib
import re
import sys
-from typing import Union
import nonebot
-from bs4 import BeautifulSoup as bs
-from nonebot.log import default_format, logger
from nonebot.plugin import require
-from nonebot_plugin_saa import Image, MessageSegmentFactory, Text
+from bs4 import BeautifulSoup as bs
+from nonebot.log import logger, default_format
+from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
-from ..plugin_config import plugin_config
-from .context import ProcessContext
from .http import http_client
+from .context import ProcessContext
+from ..plugin_config import plugin_config
from .scheduler_config import SchedulerConfig, scheduler
__all__ = [
@@ -30,7 +29,7 @@ class Singleton(type):
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
- cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+ cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]
@@ -63,7 +62,7 @@ def html_to_text(html: str, query_dict: dict = {}) -> str:
class Filter:
def __init__(self) -> None:
- self.level: Union[int, str] = "DEBUG"
+ self.level: int | str = "DEBUG"
def __call__(self, record):
module_name: str = record["name"]
@@ -71,9 +70,7 @@ class Filter:
if module:
module_name = getattr(module, "__module_name__", module_name)
record["name"] = module_name.split(".")[0]
- levelno = (
- logger.level(self.level).no if isinstance(self.level, str) else self.level
- )
+ levelno = logger.level(self.level).no if isinstance(self.level, str) else self.level
nonebot_warning_level = logger.level("WARNING").no
return (
record["level"].no >= levelno
@@ -94,11 +91,7 @@ if plugin_config.bison_filter_log:
)
config = nonebot.get_driver().config
logger.success("Muted info & success from nonebot")
- default_filter.level = (
- ("DEBUG" if config.debug else "INFO")
- if config.log_level is None
- else config.log_level
- )
+ default_filter.level = ("DEBUG" if config.debug else "INFO") if config.log_level is None else config.log_level
def text_similarity(str1, str2) -> float:
diff --git a/nonebot_bison/utils/context.py b/nonebot_bison/utils/context.py
index d2eb4cd..9a66390 100644
--- a/nonebot_bison/utils/context.py
+++ b/nonebot_bison/utils/context.py
@@ -1,6 +1,6 @@
from base64 import b64encode
-from httpx import AsyncClient, Response
+from httpx import Response, AsyncClient
class ProcessContext:
@@ -33,8 +33,13 @@ class ProcessContext:
res = []
for req in self.reqs:
if self._should_print_content(req):
- log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} {req.text}"
+ log_content = (
+ f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} {req.text}"
+ )
else:
- log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} b64encoded: {b64encode(req.content[:50]).decode()}"
+ log_content = (
+ f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} "
+ f"b64encoded: {b64encode(req.content[:50]).decode()}"
+ )
res.append(log_content)
return res
diff --git a/nonebot_bison/utils/http.py b/nonebot_bison/utils/http.py
index 6746ff2..08bfb43 100644
--- a/nonebot_bison/utils/http.py
+++ b/nonebot_bison/utils/http.py
@@ -1,5 +1,3 @@
-import functools
-
import httpx
from ..plugin_config import plugin_config
@@ -10,7 +8,6 @@ http_args = {
http_headers = {"user-agent": plugin_config.bison_ua}
-@functools.wraps(httpx.AsyncClient)
def http_client(*args, **kwargs):
if headers := kwargs.get("headers"):
new_headers = http_headers.copy()
@@ -18,7 +15,4 @@ def http_client(*args, **kwargs):
kwargs["headers"] = new_headers
else:
kwargs["headers"] = http_headers
- return httpx.AsyncClient(*args, **kwargs)
-
-
-http_client = functools.partial(http_client, **http_args)
+ return httpx.AsyncClient(*args, **kwargs, **http_args)
diff --git a/nonebot_bison/utils/scheduler_config.py b/nonebot_bison/utils/scheduler_config.py
index 57360fa..25daab9 100644
--- a/nonebot_bison/utils/scheduler_config.py
+++ b/nonebot_bison/utils/scheduler_config.py
@@ -1,4 +1,4 @@
-from typing import Literal, Type
+from typing import Literal
from httpx import AsyncClient
@@ -7,7 +7,6 @@ from .http import http_client
class SchedulerConfig:
-
schedule_type: Literal["date", "interval", "cron"]
schedule_setting: dict
name: str
@@ -25,9 +24,7 @@ class SchedulerConfig:
return self.default_http_client
-def scheduler(
- schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict
-) -> Type[SchedulerConfig]:
+def scheduler(schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict) -> type[SchedulerConfig]:
return type(
"AnonymousScheduleConfig",
(SchedulerConfig,),
diff --git a/pyproject.toml b/pyproject.toml
index 80d3972..a75b117 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -85,7 +85,7 @@ line-length = 120
target-version = "py310"
[tool.black]
-line-length = 120
+line-length = 118
target-version = ["py310", "py311"]
include = '\.pyi?$'
extend-exclude = '''