🎨 Adjust the code to pass ruff's checks

Azide 2023-07-16 00:22:20 +08:00 committed by felinae98
parent f232ce4c3e
commit dba8f2a9cb
42 changed files with 414 additions and 757 deletions

View File

@ -6,25 +6,18 @@ require("nonebot_plugin_saa")
import nonebot_plugin_saa
from . import (
admin_page,
bootstrap,
config,
platform,
post,
scheduler,
send,
sub_manager,
types,
utils,
)
from .plugin_config import PlugConfig, plugin_config
from . import post, send, types, utils, config, platform, bootstrap, scheduler, admin_page, sub_manager
__help__version__ = "0.7.3"
nonebot_plugin_saa.enable_auto_select_bot()
__help__plugin__name__ = "nonebot_bison"
__usage__ = f"本bot可以提供b站、微博等社交媒体的消息订阅详情请查看本bot文档或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,发送“查询订阅”或“删除订阅”管理订阅"
__usage__ = (
"本bot可以提供b站、微博等社交媒体的消息订阅详情请查看本bot文档"
f"或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,"
f"发送“查询订阅”或“删除订阅”管理订阅"
)
__supported_adapters__ = nonebot_plugin_saa.__plugin_meta__.supported_adapters
@ -41,6 +34,7 @@ __plugin_meta__ = PluginMetadata(
__all__ = [
"admin_page",
"bootstrap",
"config",
"sub_manager",
"post",

View File

@ -1,17 +1,16 @@
import os
from pathlib import Path
from typing import Union
from nonebot import get_driver, on_command
from nonebot.adapters.onebot.v11 import Bot
from nonebot.adapters.onebot.v11.event import PrivateMessageEvent
from nonebot.drivers.fastapi import Driver
from nonebot.log import logger
from nonebot.rule import to_me
from nonebot.typing import T_State
from nonebot import get_driver, on_command
from nonebot.drivers.fastapi import Driver
from nonebot.adapters.onebot.v11 import Bot
from nonebot.adapters.onebot.v11.event import PrivateMessageEvent
from ..plugin_config import plugin_config
from .api import router as api_router
from ..plugin_config import plugin_config
from .token_manager import token_manager as tm
STATIC_PATH = (Path(__file__).parent / "dist").resolve()
@ -28,11 +27,9 @@ def init_fastapi():
class SinglePageApplication(StaticFiles):
def __init__(self, directory: os.PathLike, index="index.html"):
self.index = index
super().__init__(
directory=directory, packages=None, html=True, check_dir=True
)
super().__init__(directory=directory, packages=None, html=True, check_dir=True)
def lookup_path(self, path: str) -> tuple[str, Union[os.stat_result, None]]:
def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
full_path, stat_res = super().lookup_path(path)
if stat_res is None:
return super().lookup_path(self.index)
@ -45,9 +42,7 @@ def init_fastapi():
description="nonebot-bison webui and api",
)
nonebot_app.include_router(api_router)
nonebot_app.mount(
"/", SinglePageApplication(directory=static_path), name="bison-frontend"
)
nonebot_app.mount("/", SinglePageApplication(directory=static_path), name="bison-frontend")
app = driver.server_app
app.mount("/bison", nonebot_app, "nonebot-bison")
@ -63,10 +58,9 @@ def init_fastapi():
if host in ["0.0.0.0", "127.0.0.1"]:
host = "localhost"
logger.opt(colors=True).info(
f"Nonebot Bison frontend will be running at: "
f"<b><u>http://{host}:{port}/bison</u></b>"
f"Nonebot Bison frontend will be running at: " f"<b><u>http://{host}:{port}/bison</u></b>"
)
logger.opt(colors=True).info(f"该页面不能被直接访问请私聊bot <b><u>后台管理</u></b> 以获取可访问地址")
logger.opt(colors=True).info("该页面不能被直接访问请私聊bot <b><u>后台管理</u></b> 以获取可访问地址")
def register_get_token_handler():
@ -93,6 +87,4 @@ if (STATIC_PATH / "index.html").exists():
else:
logger.warning("your driver is not fastapi, webui feature will be disabled")
else:
logger.warning(
"Frontend file not found, please compile it or use docker or pypi version"
)
logger.warning("Frontend file not found, please compile it or use docker or pypi version")

View File

@ -1,35 +1,30 @@
import nonebot
from fastapi import status
from fastapi.exceptions import HTTPException
from fastapi.param_functions import Depends
from fastapi.routing import APIRouter
from fastapi.security.oauth2 import OAuth2PasswordBearer
from fastapi.param_functions import Depends
from fastapi.exceptions import HTTPException
from nonebot_plugin_saa import TargetQQGroup
from fastapi.security.oauth2 import OAuth2PasswordBearer
from nonebot_plugin_saa.utils.auto_select_bot import get_bot
from ..apis import check_sub_target
from ..config import (
NoSuchSubscribeException,
NoSuchTargetException,
NoSuchUserException,
config,
)
from ..config.db_config import SubscribeDupException
from ..platform import platform_manager
from ..types import Target as T_Target
from ..types import WeightConfig
from ..utils.get_bot import get_groups
from ..apis import check_sub_target
from .jwt import load_jwt, pack_jwt
from ..types import Target as T_Target
from ..utils.get_bot import get_groups
from ..platform import platform_manager
from .token_manager import token_manager
from ..config.db_config import SubscribeDupException
from ..config import NoSuchUserException, NoSuchTargetException, NoSuchSubscribeException, config
from .types import (
AddSubscribeReq,
TokenResp,
GlobalConf,
PlatformConfig,
StatusResp,
SubscribeResp,
PlatformConfig,
AddSubscribeReq,
SubscribeConfig,
SubscribeGroupDetail,
SubscribeResp,
TokenResp,
)
router = APIRouter(prefix="/api", tags=["api"])
@ -44,9 +39,7 @@ async def get_jwt_obj(token: str = Depends(oath_scheme)):
return obj
async def check_group_permission(
groupNumber: int, token_obj: dict = Depends(get_jwt_obj)
):
async def check_group_permission(groupNumber: int, token_obj: dict = Depends(get_jwt_obj)):
groups = token_obj["groups"]
for group in groups:
if int(groupNumber) == group["id"]:
@ -95,15 +88,13 @@ async def auth(token: str) -> TokenResp:
jwt_obj = {
"id": qq,
"type": "admin",
"groups": list(
map(
lambda info: {
"id": info["group_id"],
"name": info["group_name"],
},
await get_groups(),
)
),
"groups": [
{
"id": info["group_id"],
"name": info["group_name"],
}
for info in await get_groups()
],
}
ret_obj = TokenResp(
type="admin",
@ -134,18 +125,16 @@ async def get_subs_info(jwt_obj: dict = Depends(get_jwt_obj)) -> SubscribeResp:
for group in groups:
group_id = group["id"]
raw_subs = await config.list_subscribe(TargetQQGroup(group_id=group_id))
subs = list(
map(
lambda sub: SubscribeConfig(
platformName=sub.target.platform_name,
targetName=sub.target.target_name,
cats=sub.categories,
tags=sub.tags,
target=sub.target.target,
),
raw_subs,
subs = [
SubscribeConfig(
platformName=sub.target.platform_name,
targetName=sub.target.target_name,
cats=sub.categories,
tags=sub.tags,
target=sub.target.target,
)
)
for sub in raw_subs
]
res[group_id] = SubscribeGroupDetail(name=group["name"], subscribes=subs)
return res
@ -174,9 +163,7 @@ async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp:
@router.delete("/subs", dependencies=[Depends(check_group_permission)])
async def del_group_sub(groupNumber: int, platformName: str, target: str):
try:
await config.del_subscribe(
TargetQQGroup(group_id=groupNumber), target, platformName
)
await config.del_subscribe(TargetQQGroup(group_id=groupNumber), target, platformName)
except (NoSuchUserException, NoSuchSubscribeException):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe")
return StatusResp(ok=True, msg="")
@ -204,13 +191,9 @@ async def get_weight_config():
@router.put("/weight", dependencies=[Depends(check_is_superuser)])
async def update_weigth_config(
platformName: str, target: str, weight_config: WeightConfig
):
async def update_weigth_config(platformName: str, target: str, weight_config: WeightConfig):
try:
await config.update_time_weight_config(
T_Target(target), platformName, weight_config
)
await config.update_time_weight_config(T_Target(target), platformName, weight_config)
except NoSuchTargetException:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such subscribe")
return StatusResp(ok=True, msg="")
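The `auth` and `get_subs_info` hunks replace `list(map(lambda ...))` with list comprehensions; ruff's flake8-comprehensions rules flag the `map` + `lambda` form (e.g. C417), and the comprehension avoids the extra lambda indirection. A minimal sketch of the same rewrite over made-up group data:

```python
groups = [{"group_id": 1, "group_name": "alpha"}, {"group_id": 2, "group_name": "beta"}]

# Before: list(map(lambda info: {"id": info["group_id"], "name": info["group_name"]}, groups))
# After: the equivalent, flatter list comprehension.
summaries = [{"id": info["group_id"], "name": info["group_name"]} for info in groups]

assert summaries == [{"id": 1, "name": "alpha"}, {"id": 2, "name": "beta"}]
```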

View File

@ -1,7 +1,6 @@
import datetime
import random
import string
from typing import Optional
import datetime
import jwt
@ -16,8 +15,8 @@ def pack_jwt(obj: dict) -> str:
)
def load_jwt(token: str) -> Optional[dict]:
def load_jwt(token: str) -> dict | None:
try:
return jwt.decode(token, _key, algorithms=["HS256"])
except:
except Exception:
return None
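`load_jwt` now catches `Exception` instead of using a bare `except:`; ruff flags the bare clause (pycodestyle E722) because it would also swallow `KeyboardInterrupt` and `SystemExit`. A standalone sketch of the same guard around a decode step (hypothetical JSON payloads, not this project's JWT handling):

```python
import json


def load_payload(raw: str) -> dict | None:
    try:
        payload = json.loads(raw)
        return payload if isinstance(payload, dict) else None
    except Exception:  # still lets KeyboardInterrupt/SystemExit propagate
        return None


print(load_payload('{"id": 1}'))  # {'id': 1}
print(load_payload("not json"))   # None
```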

View File

@ -1,6 +1,5 @@
import random
import string
from typing import Optional
from expiringdict import ExpiringDict
@ -9,7 +8,7 @@ class TokenManager:
def __init__(self):
self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60 * 10)
def get_user(self, token: str) -> Optional[tuple]:
def get_user(self, token: str) -> tuple | None:
res = self.token_manager.get(token)
assert res is None or isinstance(res, tuple)
return res

View File

@ -1,2 +1,4 @@
from .db_config import config
from .utils import NoSuchSubscribeException, NoSuchTargetException, NoSuchUserException
from .db_config import config as config
from .utils import NoSuchUserException as NoSuchUserException
from .utils import NoSuchTargetException as NoSuchTargetException
from .utils import NoSuchSubscribeException as NoSuchSubscribeException
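The rewritten package `__init__` uses the redundant-looking `from .utils import X as X` form: re-importing a name under its own name is the conventional signal, recognized by linters and type checkers, that the import is a deliberate re-export rather than an unused one. A sketch of the two equivalent ways to express this in a hypothetical package `__init__.py` (module and exception names are made up):

```python
# mypkg/__init__.py -- both spellings mark names as intentional re-exports.

# Option 1: the `X as X` alias form used in this commit.
from .errors import NoSuchUserError as NoSuchUserError
from .errors import NoSuchTargetError as NoSuchTargetError

# Option 2: a plain import plus an explicit __all__ listing.
# from .errors import NoSuchUserError, NoSuchTargetError
# __all__ = ["NoSuchUserError", "NoSuchTargetError"]
```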

View File

@ -1,20 +1,19 @@
import json
import os
from collections import defaultdict
from datetime import datetime
import json
from os import path
from pathlib import Path
from typing import DefaultDict, Literal, Mapping, TypedDict
from datetime import datetime
from collections import defaultdict
from typing import Literal, TypedDict
import nonebot
from nonebot.log import logger
from tinydb import Query, TinyDB
from ..utils import Singleton
from ..types import User, Target
from ..platform import platform_manager
from ..plugin_config import plugin_config
from ..types import Target, User
from ..utils import Singleton
from .utils import NoSuchSubscribeException, NoSuchUserException
from .utils import NoSuchUserException, NoSuchSubscribeException
supported_target_type = platform_manager.keys()
@ -89,17 +88,16 @@ class Config(metaclass=Singleton):
self.target_user_cat_cache = {}
self.target_user_tag_cache = {}
self.target_list = {}
self.next_index: DefaultDict[str, int] = defaultdict(lambda: 0)
self.next_index: defaultdict[str, int] = defaultdict(lambda: 0)
else:
self.available = False
def add_subscribe(
self, user, user_type, target, target_name, target_type, cats, tags
):
def add_subscribe(self, user, user_type, target, target_name, target_type, cats, tags):
user_query = Query()
query = (user_query.user == user) & (user_query.user_type == user_type)
if user_data := self.user_target.get(query):
# update
assert not isinstance(user_data, list)
subs: list = user_data.get("subs", [])
subs.append(
{
@ -132,9 +130,8 @@ class Config(metaclass=Singleton):
def list_subscribe(self, user, user_type) -> list[SubscribeContent]:
query = Query()
if user_sub := self.user_target.get(
(query.user == user) & (query.user_type == user_type)
):
if user_sub := self.user_target.get((query.user == user) & (query.user_type == user_type)):
assert not isinstance(user_sub, list)
return user_sub["subs"]
return []
@ -146,6 +143,7 @@ class Config(metaclass=Singleton):
query = (user_query.user == user) & (user_query.user_type == user_type)
if not (query_res := self.user_target.get(query)):
raise NoSuchUserException()
assert not isinstance(query_res, list)
subs = query_res.get("subs", [])
for idx, sub in enumerate(subs):
if sub.get("target") == target and sub.get("target_type") == target_type:
@ -155,13 +153,12 @@ class Config(metaclass=Singleton):
return
raise NoSuchSubscribeException()
def update_subscribe(
self, user, user_type, target, target_name, target_type, cats, tags
):
def update_subscribe(self, user, user_type, target, target_name, target_type, cats, tags):
user_query = Query()
query = (user_query.user == user) & (user_query.user_type == user_type)
if user_data := self.user_target.get(query):
# update
assert not isinstance(user_data, list)
subs: list = user_data.get("subs", [])
find_flag = False
for item in subs:
@ -182,19 +179,13 @@ class Config(metaclass=Singleton):
def update_send_cache(self):
res = {target_type: defaultdict(list) for target_type in supported_target_type}
cat_res = {
target_type: defaultdict(lambda: defaultdict(list))
for target_type in supported_target_type
}
tag_res = {
target_type: defaultdict(lambda: defaultdict(list))
for target_type in supported_target_type
}
cat_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
tag_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
# res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
to_del = []
for user in self.user_target.all():
for sub in user.get("subs", []):
if not sub.get("target_type") in supported_target_type:
if sub.get("target_type") not in supported_target_type:
to_del.append(
{
"user": user["user"],
@ -204,36 +195,28 @@ class Config(metaclass=Singleton):
}
)
continue
res[sub["target_type"]][sub["target"]].append(
User(user["user"], user["user_type"])
)
cat_res[sub["target_type"]][sub["target"]][
"{}-{}".format(user["user_type"], user["user"])
] = sub["cats"]
tag_res[sub["target_type"]][sub["target"]][
"{}-{}".format(user["user_type"], user["user"])
] = sub["tags"]
res[sub["target_type"]][sub["target"]].append(User(user["user"], user["user_type"]))
cat_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[
"cats"
]
tag_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[
"tags"
]
self.target_user_cache = res
self.target_user_cat_cache = cat_res
self.target_user_tag_cache = tag_res
for target_type in self.target_user_cache:
self.target_list[target_type] = list(
self.target_user_cache[target_type].keys()
)
self.target_list[target_type] = list(self.target_user_cache[target_type].keys())
logger.info(f"Deleting {to_del}")
for d in to_del:
self.del_subscribe(**d)
def get_sub_category(self, target_type, target, user_type, user):
return self.target_user_cat_cache[target_type][target][
"{}-{}".format(user_type, user)
]
return self.target_user_cat_cache[target_type][target][f"{user_type}-{user}"]
def get_sub_tags(self, target_type, target, user_type, user):
return self.target_user_tag_cache[target_type][target][
"{}-{}".format(user_type, user)
]
return self.target_user_tag_cache[target_type][target][f"{user_type}-{user}"]
def get_next_target(self, target_type):
# FIXME 插入或删除target后对队列的影响但是并不是大问题

View File

@ -1,19 +1,19 @@
import asyncio
from collections import defaultdict
from datetime import datetime, time
from typing import Awaitable, Callable, Optional, Sequence
from datetime import time, datetime
from collections.abc import Callable, Sequence, Awaitable
from nonebot_plugin_datastore import create_session
from nonebot_plugin_saa import PlatformTarget
from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload
from sqlalchemy.exc import IntegrityError
from sqlalchemy import func, delete, select
from nonebot_plugin_saa import PlatformTarget
from nonebot_plugin_datastore import create_session
from ..types import Category, PlatformWeightConfigResp, Tag
from ..types import Tag
from ..types import Target as T_Target
from ..types import TimeWeightConfig, UserSubInfo, WeightConfig
from .db_model import ScheduleTimeWeight, Subscribe, Target, User
from .utils import NoSuchTargetException
from .db_model import User, Target, Subscribe, ScheduleTimeWeight
from ..types import Category, UserSubInfo, WeightConfig, TimeWeightConfig, PlatformWeightConfigResp
def _get_time():
@ -48,23 +48,17 @@ class DBConfig:
):
async with create_session() as session:
db_user_stmt = select(User).where(User.user_target == user.dict())
db_user: Optional[User] = await session.scalar(db_user_stmt)
db_user: User | None = await session.scalar(db_user_stmt)
if not db_user:
db_user = User(user_target=user.dict())
session.add(db_user)
db_target_stmt = (
select(Target)
.where(Target.platform_name == platform_name)
.where(Target.target == target)
select(Target).where(Target.platform_name == platform_name).where(Target.target == target)
)
db_target: Optional[Target] = await session.scalar(db_target_stmt)
db_target: Target | None = await session.scalar(db_target_stmt)
if not db_target:
db_target = Target(
target=target, platform_name=platform_name, target_name=target_name
)
await asyncio.gather(
*[hook(platform_name, target) for hook in self.add_target_hook]
)
db_target = Target(target=target, platform_name=platform_name, target_name=target_name)
await asyncio.gather(*[hook(platform_name, target) for hook in self.add_target_hook])
else:
db_target.target_name = target_name
subscribe = Subscribe(
@ -96,44 +90,25 @@ class DBConfig:
"""获取数据库中带有user、target信息的subscribe数据"""
async with create_session() as session:
query_stmt = (
select(Subscribe)
.join(User)
.options(selectinload(Subscribe.target), selectinload(Subscribe.user))
select(Subscribe).join(User).options(selectinload(Subscribe.target), selectinload(Subscribe.user))
)
subs = (await session.scalars(query_stmt)).all()
return subs
async def del_subscribe(
self, user: PlatformTarget, target: str, platform_name: str
):
async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str):
async with create_session() as session:
user_obj = await session.scalar(
select(User).where(User.user_target == user.dict())
)
user_obj = await session.scalar(select(User).where(User.user_target == user.dict()))
target_obj = await session.scalar(
select(Target).where(
Target.platform_name == platform_name, Target.target == target
)
)
await session.execute(
delete(Subscribe).where(
Subscribe.user == user_obj, Subscribe.target == target_obj
)
select(Target).where(Target.platform_name == platform_name, Target.target == target)
)
await session.execute(delete(Subscribe).where(Subscribe.user == user_obj, Subscribe.target == target_obj))
target_count = await session.scalar(
select(func.count())
.select_from(Subscribe)
.where(Subscribe.target == target_obj)
select(func.count()).select_from(Subscribe).where(Subscribe.target == target_obj)
)
if target_count == 0:
# delete empty target
await asyncio.gather(
*[
hook(platform_name, T_Target(target))
for hook in self.delete_target_hook
]
)
await asyncio.gather(*[hook(platform_name, T_Target(target)) for hook in self.delete_target_hook])
await session.commit()
async def update_subscribe(
@ -165,29 +140,22 @@ class DBConfig:
async def get_platform_target(self, platform_name: str) -> Sequence[Target]:
async with create_session() as sess:
subq = select(Subscribe.target_id).distinct().subquery()
query = (
select(Target).join(subq).where(Target.platform_name == platform_name)
)
query = select(Target).join(subq).where(Target.platform_name == platform_name)
return (await sess.scalars(query)).all()
async def get_time_weight_config(
self, target: T_Target, platform_name: str
) -> WeightConfig:
async def get_time_weight_config(self, target: T_Target, platform_name: str) -> WeightConfig:
async with create_session() as sess:
time_weight_conf = (
await sess.scalars(
select(ScheduleTimeWeight)
.where(
Target.platform_name == platform_name, Target.target == target
)
.where(Target.platform_name == platform_name, Target.target == target)
.join(Target)
)
).all()
targetObj = await sess.scalar(
select(Target).where(
Target.platform_name == platform_name, Target.target == target
)
select(Target).where(Target.platform_name == platform_name, Target.target == target)
)
assert targetObj
return WeightConfig(
default=targetObj.default_schedule_weight,
time_config=[
@ -200,22 +168,16 @@ class DBConfig:
],
)
async def update_time_weight_config(
self, target: T_Target, platform_name: str, conf: WeightConfig
):
async def update_time_weight_config(self, target: T_Target, platform_name: str, conf: WeightConfig):
async with create_session() as sess:
targetObj = await sess.scalar(
select(Target).where(
Target.platform_name == platform_name, Target.target == target
)
select(Target).where(Target.platform_name == platform_name, Target.target == target)
)
if not targetObj:
raise NoSuchTargetException()
target_id = targetObj.id
targetObj.default_schedule_weight = conf.default
delete_statement = delete(ScheduleTimeWeight).where(
ScheduleTimeWeight.target_id == target_id
)
delete_statement = delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id)
await sess.execute(delete_statement)
for time_conf in conf.time_config:
new_conf = ScheduleTimeWeight(
@ -243,18 +205,13 @@ class DBConfig:
key = f"{target.platform_name}-{target.target}"
weight = target.default_schedule_weight
for time_conf in target.time_weight:
if (
time_conf.start_time <= cur_time
and time_conf.end_time > cur_time
):
if time_conf.start_time <= cur_time and time_conf.end_time > cur_time:
weight = time_conf.weight
break
res[key] = weight
return res
async def get_platform_target_subscribers(
self, platform_name: str, target: T_Target
) -> list[UserSubInfo]:
async def get_platform_target_subscribers(self, platform_name: str, target: T_Target) -> list[UserSubInfo]:
async with create_session() as sess:
query = (
select(Subscribe)
@ -263,16 +220,14 @@ class DBConfig:
.options(selectinload(Subscribe.user))
)
subsribes = (await sess.scalars(query)).all()
return list(
map(
lambda subscribe: UserSubInfo(
PlatformTarget.deserialize(subscribe.user.user_target),
subscribe.categories,
subscribe.tags,
),
subsribes,
return [
UserSubInfo(
PlatformTarget.deserialize(subscribe.user.user_target),
subscribe.categories,
subscribe.tags,
)
)
for subscribe in subsribes
]
async def get_all_weight_config(
self,
@ -281,9 +236,7 @@ class DBConfig:
async with create_session() as sess:
query = select(Target)
targets = (await sess.scalars(query)).all()
query = select(ScheduleTimeWeight).options(
selectinload(ScheduleTimeWeight.target)
)
query = select(ScheduleTimeWeight).options(selectinload(ScheduleTimeWeight.target))
time_weights = (await sess.scalars(query)).all()
for target in targets:
@ -293,9 +246,7 @@ class DBConfig:
target=T_Target(target.target),
target_name=target.target_name,
platform_name=platform_name,
weight=WeightConfig(
default=target.default_schedule_weight, time_config=[]
),
weight=WeightConfig(default=target.default_schedule_weight, time_config=[]),
)
for time_weight_config in time_weights:
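Besides the widespread `Optional[X]` → `X | None` and `list(map(lambda ...))` → comprehension rewrites, this file moves `Callable`, `Sequence`, and `Awaitable` from `typing` to `collections.abc`; the `typing` aliases are deprecated, and ruff's pyupgrade rules prefer the `collections.abc` originals. A short runnable sketch in the same annotation style (the hook signature is illustrative, not the plugin's actual hook API):

```python
import asyncio
from collections.abc import Awaitable, Callable, Sequence


async def run_hooks(hooks: Sequence[Callable[[str], Awaitable[None]]], name: str) -> None:
    # Fan out to every registered hook, mirroring the gather(*[hook(...) ...]) style above.
    await asyncio.gather(*[hook(name) for hook in hooks])


async def demo_hook(name: str) -> None:
    print(f"hook called for {name}")


asyncio.run(run_hooks([demo_hook], "demo"))
```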

View File

@ -1,22 +1,17 @@
from nonebot.log import logger
from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from sqlalchemy.ext.asyncio.session import AsyncSession
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from .db_model import User, Target, Subscribe
from .config_legacy import Config, ConfigContent, drop
from .db_model import Subscribe, Target, User
async def data_migrate():
config = Config()
if config.available:
logger.warning("You are still using legacy db, migrating to sqlite")
all_subs: list[ConfigContent] = list(
map(
lambda item: ConfigContent(**item),
config.get_all_subscribe().all(),
)
)
all_subs: list[ConfigContent] = [ConfigContent(**item) for item in config.get_all_subscribe().all()]
async with AsyncSession(get_engine()) as sess:
user_to_create = []
subscribe_to_create = []
@ -37,8 +32,7 @@ async def data_migrate():
if key in user_sub_set:
# a user subscribe a target twice
logger.error(
f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次,"
"随机采用了一个订阅"
f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次,随机采用了一个订阅" # noqa: E501
)
continue
user_sub_set.add(key)
@ -69,11 +63,7 @@ async def data_migrate():
tags=sub["tags"],
)
subscribe_to_create.append(subscribe_obj)
sess.add_all(
user_to_create
+ list(map(lambda x: x[0], platform_target_map.values()))
+ subscribe_to_create
)
sess.add_all(user_to_create + [x[0] for x in platform_target_map.values()] + subscribe_to_create)
await sess.commit()
drop()
logger.info("migrate success")

View File

@ -1,7 +1,7 @@
"""init db
Revision ID: 0571870f5222
Revises:
Revises:
Create Date: 2022-03-21 19:18:13.762626
"""

View File

@ -5,7 +5,6 @@ Revises: 5f3370328e44
Create Date: 2023-01-15 19:04:54.987491
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.

View File

@ -5,7 +5,6 @@ Revises: 0571870f5222
Create Date: 2022-03-26 19:46:50.910721
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
@ -18,14 +17,10 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("subscribe", schema=None) as batch_op:
batch_op.create_unique_constraint(
"unique-subscribe-constraint", ["target_id", "user_id"]
)
batch_op.create_unique_constraint("unique-subscribe-constraint", ["target_id", "user_id"])
with op.batch_alter_table("target", schema=None) as batch_op:
batch_op.create_unique_constraint(
"unique-target-constraint", ["target", "platform_name"]
)
batch_op.create_unique_constraint("unique-target-constraint", ["target", "platform_name"])
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"])

View File

@ -1,15 +1,14 @@
from abc import ABC
from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo
from pydantic import BaseModel
from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo
from ....types import Category, Tag
from ....types import Tag, Category
class NBESFBase(BaseModel, ABC):
version: int # 表示nbesf格式版本有效版本从1开始
groups: list = list()
groups: list = []
class Config:
orm_mode = True

View File

@ -1,25 +1,26 @@
from typing import cast
from collections import defaultdict
from typing import Callable, cast
from collections.abc import Callable
from nonebot.log import logger
from nonebot_plugin_datastore.db import create_session
from nonebot_plugin_saa import PlatformTarget
from sqlalchemy import select
from sqlalchemy.orm.strategy_options import selectinload
from nonebot.log import logger
from sqlalchemy.sql.selectable import Select
from nonebot_plugin_saa import PlatformTarget
from nonebot_plugin_datastore.db import create_session
from sqlalchemy.orm.strategy_options import selectinload
from ..db_model import Subscribe, User
from .nbesf_model import NBESFBase, v1, v2
from .utils import NBESFVerMatchErr
from ..db_model import User, Subscribe
from .nbesf_model import NBESFBase, v1, v2
async def subscribes_export(selector: Callable[[Select], Select]) -> v2.SubGroup:
"""
将Bison订阅导出为 Nonebot Bison Exchangable Subscribes File 标准格式的 SubGroup 类型数据
selector:
sqlalchemy Select 对象的操作函数用于限定查询范围 e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group")
sqlalchemy Select 对象的操作函数用于限定查询范围
e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group")
"""
async with create_session() as sess:
sub_stmt = select(Subscribe).join(User)

View File

@ -1,23 +1,22 @@
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from pkgutil import iter_modules
from typing import DefaultDict, Type
from collections import defaultdict
from importlib import import_module
from .platform import Platform, make_no_target_group
_package_dir = str(Path(__file__).resolve().parent)
for (_, module_name, _) in iter_modules([_package_dir]):
for _, module_name, _ in iter_modules([_package_dir]):
import_module(f"{__name__}.{module_name}")
_platform_list: DefaultDict[str, list[Type[Platform]]] = defaultdict(list)
_platform_list: defaultdict[str, list[type[Platform]]] = defaultdict(list)
for _platform in Platform.registry:
if not _platform.enabled:
continue
_platform_list[_platform.platform_name].append(_platform)
platform_manager: dict[str, Type[Platform]] = dict()
platform_manager: dict[str, type[Platform]] = {}
for name, platform_list in _platform_list.items():
if len(platform_list) == 1:
platform_manager[name] = platform_list[0]
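`DefaultDict[str, list[Type[Platform]]]` becomes `defaultdict[str, list[type[Platform]]]`, and `dict()` becomes `{}`: on Python 3.9+ the builtin and `collections` classes are subscriptable, so the capitalized `typing` aliases are unnecessary (ruff's pyupgrade rules). A small sketch of a registry keyed the same way, with a hypothetical base class standing in for `Platform`:

```python
from collections import defaultdict


class Plugin:
    name = "base"


class FooPlugin(Plugin):
    name = "foo"


# Builtin generics: defaultdict[...] and type[...] replace typing.DefaultDict/typing.Type.
registry: defaultdict[str, list[type[Plugin]]] = defaultdict(list)
registry[FooPlugin.name].append(FooPlugin)
print(registry["foo"])  # [<class '__main__.FooPlugin'>]
```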

View File

@ -1,25 +1,23 @@
import json
from typing import Any, Optional
from typing import Any
from bs4 import BeautifulSoup as bs
from httpx import AsyncClient
from nonebot.plugin import require
from bs4 import BeautifulSoup as bs
from ..post import Post
from ..types import Category, RawPost, Target
from ..types import Target, RawPost, Category
from ..utils.scheduler_config import SchedulerConfig
from .platform import CategoryNotRecognize, NewMessage, StatusChange
from .platform import NewMessage, StatusChange, CategoryNotRecognize
class ArknightsSchedConf(SchedulerConfig):
name = "arknights"
schedule_type = "interval"
schedule_setting = {"seconds": 30}
class Arknights(NewMessage):
categories = {1: "游戏公告"}
platform_name = "arknights"
name = "明日方舟游戏信息"
@ -30,9 +28,7 @@ class Arknights(NewMessage):
has_target = False
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]:
@ -92,7 +88,6 @@ class Arknights(NewMessage):
class AkVersion(StatusChange):
categories = {2: "更新信息"}
platform_name = "arknights"
name = "明日方舟游戏信息"
@ -103,15 +98,11 @@ class AkVersion(StatusChange):
has_target = False
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
return "明日方舟游戏信息"
async def get_status(self, _):
res_ver = await self.client.get(
"https://ak-conf.hypergryph.com/config/prod/official/IOS/version"
)
res_ver = await self.client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version")
res_preanounce = await self.client.get(
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json"
)
@ -121,20 +112,10 @@ class AkVersion(StatusChange):
def compare_status(self, _, old_status, new_status):
res = []
if (
old_status.get("preAnnounceType") == 2
and new_status.get("preAnnounceType") == 0
):
res.append(
Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息")
)
elif (
old_status.get("preAnnounceType") == 0
and new_status.get("preAnnounceType") == 2
):
res.append(
Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息")
)
if old_status.get("preAnnounceType") == 2 and new_status.get("preAnnounceType") == 0:
res.append(Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息")) # noqa: E501
elif old_status.get("preAnnounceType") == 0 and new_status.get("preAnnounceType") == 2:
res.append(Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息")) # noqa: E501
if old_status.get("clientVersion") != new_status.get("clientVersion"):
res.append(Post("arknights", text="游戏本体更新(大更新)", target_name="明日方舟更新信息"))
if old_status.get("resVersion") != new_status.get("resVersion"):
@ -149,7 +130,6 @@ class AkVersion(StatusChange):
class MonsterSiren(NewMessage):
categories = {3: "塞壬唱片新闻"}
platform_name = "arknights"
name = "明日方舟游戏信息"
@ -160,15 +140,11 @@ class MonsterSiren(NewMessage):
has_target = False
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]:
raw_data = await self.client.get(
"https://monster-siren.hypergryph.com/api/news"
)
raw_data = await self.client.get("https://monster-siren.hypergryph.com/api/news")
return raw_data.json()["data"]["list"]
def get_id(self, post: RawPost) -> Any:
@ -182,14 +158,12 @@ class MonsterSiren(NewMessage):
async def parse(self, raw_post: RawPost) -> Post:
url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}'
res = await self.client.get(
f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}'
)
res = await self.client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}')
raw_data = res.json()
content = raw_data["data"]["content"]
content = content.replace("</p>", "</p>\n")
soup = bs(content, "html.parser")
imgs = list(map(lambda x: x["src"], soup("img")))
imgs = [x["src"] for x in soup("img")]
text = f'{raw_post["title"]}\n{soup.text.strip()}'
return Post(
"monster-siren",
@ -203,7 +177,6 @@ class MonsterSiren(NewMessage):
class TerraHistoricusComic(NewMessage):
categories = {4: "泰拉记事社漫画"}
platform_name = "arknights"
name = "明日方舟游戏信息"
@ -214,15 +187,11 @@ class TerraHistoricusComic(NewMessage):
has_target = False
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]:
raw_data = await self.client.get(
"https://terra-historicus.hypergryph.com/api/recentUpdate"
)
raw_data = await self.client.get("https://terra-historicus.hypergryph.com/api/recentUpdate")
return raw_data.json()["data"]
def get_id(self, post: RawPost) -> Any:
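Two of the rewritten `compare_status` branches keep their long announcement strings on a single line and append `# noqa: E501`, which silences the line-length check for that line only instead of disabling the rule project-wide. A sketch of the same per-line escape hatch on a made-up log message:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")

# The trailing comment opts only this line out of E501; every other line is still checked.
logger.info("Maintenance notice appeared on the login screen, which usually means scheduled downtime is starting; the sentence stays unwrapped for readability")  # noqa: E501
```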

View File

@ -1,14 +1,14 @@
import json
import re
import json
from typing import Any
from copy import deepcopy
from datetime import datetime, timedelta
from enum import Enum, unique
from typing import Any, Literal, Optional
from typing_extensions import Self
from datetime import datetime, timedelta
from httpx import AsyncClient
from nonebot.log import logger
from pydantic import BaseModel, Field
from typing_extensions import Self
from pydantic import Field, BaseModel
from ..post import Post
from ..types import ApiError, Category, RawPost, Tag, Target
@ -25,9 +25,7 @@ class BilibiliSchedConf(SchedulerConfig):
cookie_expire_time = timedelta(hours=5)
def __init__(self):
self._client_refresh_time = datetime(
year=2000, month=1, day=1
) # an expired time
self._client_refresh_time = datetime(year=2000, month=1, day=1) # an expired time
super().__init__()
async def _init_session(self):
@ -69,12 +67,8 @@ class Bilibili(NewMessage):
parse_target_promot = "请输入用户主页的链接"
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
res = await client.get(
"https://api.bilibili.com/x/web-interface/card", params={"mid": target}
)
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target})
res.raise_for_status()
res_data = res.json()
if res_data["code"]:
@ -129,12 +123,7 @@ class Bilibili(NewMessage):
return self._do_get_category(post_type)
def get_tags(self, raw_post: RawPost) -> list[Tag]:
return [
*map(
lambda tp: tp["topic_name"],
raw_post["display"]["topic_info"]["topic_details"],
)
]
return [*(tp["topic_name"] for tp in raw_post["display"]["topic_info"]["topic_details"])]
def _get_info(self, post_type: Category, card) -> tuple[str, list]:
if post_type == 1:
@ -178,24 +167,16 @@ class Bilibili(NewMessage):
url = ""
if post_type == 1:
# 一般动态
url = "https://t.bilibili.com/{}".format(
raw_post["desc"]["dynamic_id_str"]
)
url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"])
elif post_type == 2:
# 专栏文章
url = "https://www.bilibili.com/read/cv{}".format(
raw_post["desc"]["rid"]
)
url = "https://www.bilibili.com/read/cv{}".format(raw_post["desc"]["rid"])
elif post_type == 3:
# 视频
url = "https://www.bilibili.com/video/{}".format(
raw_post["desc"]["bvid"]
)
url = "https://www.bilibili.com/video/{}".format(raw_post["desc"]["bvid"])
elif post_type == 4:
# 纯文字
url = "https://t.bilibili.com/{}".format(
raw_post["desc"]["dynamic_id_str"]
)
url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"])
text, pic = self._get_info(post_type, card_content)
elif post_type == 5:
# 转发
@ -261,10 +242,7 @@ class Bilibililive(StatusChange):
def get_live_action(self, old_info: Self) -> "Bilibililive.LiveAction":
status = Bilibililive.LiveStatus
action = Bilibililive.LiveAction
if (
old_info.live_status in [status.OFF, status.CYCLE]
and self.live_status == status.ON
):
if old_info.live_status in [status.OFF, status.CYCLE] and self.live_status == status.ON:
return action.TURN_ON
elif old_info.live_status == status.ON and self.live_status in [
status.OFF,
@ -281,12 +259,8 @@ class Bilibililive(StatusChange):
return action.OFF
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
res = await client.get(
"https://api.bilibili.com/x/web-interface/card", params={"mid": target}
)
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target})
res_data = json.loads(res.text)
if res_data["code"]:
return None
@ -382,9 +356,7 @@ class BilibiliBangumi(StatusChange):
_url = "https://api.bilibili.com/pgc/review/user"
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
res = await client.get(cls._url, params={"media_id": target})
res_data = res.json()
if res_data["code"]:

View File

@ -1,15 +1,14 @@
from typing import Any, Optional
from typing import Any
from httpx import AsyncClient
from ..post import Post
from ..types import RawPost, Target
from ..utils import scheduler
from .platform import NewMessage
from ..types import Target, RawPost
class FF14(NewMessage):
categories = {}
platform_name = "ff14"
name = "最终幻想XIV官方公告"
@ -21,9 +20,7 @@ class FF14(NewMessage):
has_target = False
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
return "最终幻想XIV官方公告"
async def get_sub_list(self, _) -> list[RawPost]:

View File

@ -2,15 +2,15 @@ import re
import time
import traceback
from bs4 import BeautifulSoup, Tag
from httpx import AsyncClient
from nonebot.log import logger
from bs4 import Tag, BeautifulSoup
from nonebot.plugin import require
from ..post import Post
from ..types import Category, RawPost, Target
from ..types import Target, RawPost, Category
from ..utils import SchedulerConfig, http_client
from .platform import CategoryNotRecognize, CategoryNotSupport, NewMessage
from .platform import NewMessage, CategoryNotSupport, CategoryNotRecognize
class McbbsnewsSchedConf(SchedulerConfig):
@ -134,9 +134,9 @@ class McbbsNews(NewMessage):
if categoty_name in category_values:
category_id = category_keys[category_values.index(categoty_name)]
elif categoty_name in known_category_values:
raise CategoryNotSupport("McbbsNews订阅暂不支持 {}".format(categoty_name))
raise CategoryNotSupport(f"McbbsNews订阅暂不支持 {categoty_name}")
else:
raise CategoryNotRecognize("Mcbbsnews订阅尚未识别 {}".format(categoty_name))
raise CategoryNotRecognize(f"Mcbbsnews订阅尚未识别 {categoty_name}")
return category_id
async def parse(self, post: RawPost) -> Post:
@ -170,7 +170,7 @@ class McbbsNews(NewMessage):
一般而言每条新闻的长度都很可观图片生成时间比较喜人
"""
require("nonebot_plugin_htmlrender")
from nonebot_plugin_htmlrender import capture_element, text_to_pic
from nonebot_plugin_htmlrender import text_to_pic, capture_element
try:
assert url
@ -181,7 +181,7 @@ class McbbsNews(NewMessage):
device_scale_factor=3,
)
assert pic_data
except:
except Exception:
err_info = traceback.format_exc()
logger.warning(f"渲染错误:{err_info}")

View File

@ -1,23 +1,21 @@
import re
from typing import Any, Optional
from typing import Any
from httpx import AsyncClient
from ..post import Post
from ..types import ApiError, RawPost, Target
from ..utils import SchedulerConfig
from .platform import NewMessage
from ..utils import SchedulerConfig
from ..types import Target, RawPost, ApiError
class NcmSchedConf(SchedulerConfig):
name = "music.163.com"
schedule_type = "interval"
schedule_setting = {"minutes": 1}
class NcmArtist(NewMessage):
categories = {}
platform_name = "ncm-artist"
enable_tag = False
@ -29,11 +27,9 @@ class NcmArtist(NewMessage):
parse_target_promot = "请输入歌手主页包含数字ID的链接"
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
res = await client.get(
"https://music.163.com/api/artist/albums/{}".format(target),
f"https://music.163.com/api/artist/albums/{target}",
headers={"Referer": "https://music.163.com/"},
)
res_data = res.json()
@ -45,16 +41,14 @@ class NcmArtist(NewMessage):
async def parse_target(cls, target_text: str) -> Target:
if re.match(r"^\d+$", target_text):
return Target(target_text)
elif match := re.match(
r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text
):
elif match := re.match(r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text):
return Target(match.group(1))
else:
raise cls.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]:
res = await self.client.get(
"https://music.163.com/api/artist/albums/{}".format(target),
f"https://music.163.com/api/artist/albums/{target}",
headers={"Referer": "https://music.163.com/"},
)
res_data = res.json()
@ -74,13 +68,10 @@ class NcmArtist(NewMessage):
target_name = raw_post["artist"]["name"]
pics = [raw_post["picUrl"]]
url = "https://music.163.com/#/album?id={}".format(raw_post["id"])
return Post(
"ncm-artist", text=text, url=url, pics=pics, target_name=target_name
)
return Post("ncm-artist", text=text, url=url, pics=pics, target_name=target_name)
class NcmRadio(NewMessage):
categories = {}
platform_name = "ncm-radio"
enable_tag = False
@ -92,9 +83,7 @@ class NcmRadio(NewMessage):
parse_target_promot = "请输入主播电台主页包含数字ID的链接"
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
res = await client.post(
"http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"},
@ -109,9 +98,7 @@ class NcmRadio(NewMessage):
async def parse_target(cls, target_text: str) -> Target:
if re.match(r"^\d+$", target_text):
return Target(target_text)
elif match := re.match(
r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text
):
elif match := re.match(r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text):
return Target(match.group(1))
else:
raise cls.ParseTargetException()

View File

@ -1,29 +1,32 @@
import json
import ssl
import json
import time
import typing
from typing import Any
from dataclasses import dataclass
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Collection, Optional, Type
from collections.abc import Collection
import httpx
from httpx import AsyncClient
from nonebot.log import logger
from nonebot_plugin_saa import PlatformTarget
from ..plugin_config import plugin_config
from ..post import Post
from ..types import Category, RawPost, Tag, Target, UserSubInfo
from ..plugin_config import plugin_config
from ..utils import ProcessContext, SchedulerConfig
from ..types import Tag, Target, RawPost, Category, UserSubInfo
class CategoryNotSupport(Exception):
"raise in get_category, when you know the category of the post but don't want to support it or don't support its parsing yet"
"""raise in get_category, when you know the category of the post
but don't want to support it or don't support its parsing yet
"""
class CategoryNotRecognize(Exception):
"raise in get_category, when you don't know the category of post"
"""raise in get_category, when you don't know the category of post"""
class RegistryMeta(type):
@ -42,7 +45,6 @@ class RegistryMeta(type):
class PlatformMeta(RegistryMeta):
categories: dict[Category, str]
store: dict[Target, Any]
@ -60,8 +62,7 @@ class PlatformABCMeta(PlatformMeta, ABC):
class Platform(metaclass=PlatformABCMeta, base=True):
scheduler: Type[SchedulerConfig]
scheduler: type[SchedulerConfig]
ctx: ProcessContext
is_common: bool
enabled: bool
@ -70,16 +71,14 @@ class Platform(metaclass=PlatformABCMeta, base=True):
categories: dict[Category, str]
enable_tag: bool
platform_name: str
parse_target_promot: Optional[str] = None
registry: list[Type["Platform"]]
parse_target_promot: str | None = None
registry: list[type["Platform"]]
client: AsyncClient
reverse_category: dict[str, Category]
@classmethod
@abstractmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
...
@abstractmethod
@ -95,11 +94,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
return await self.fetch_new_post(target, users)
except httpx.RequestError as err:
if plugin_config.bison_show_network_warning:
logger.warning(
"network connection error: {}, url: {}".format(
type(err), err.request.url
)
)
logger.warning(f"network connection error: {type(err)}, url: {err.request.url}")
return []
except ssl.SSLError as err:
if plugin_config.bison_show_network_warning:
@ -130,7 +125,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
return Target(target_string)
@abstractmethod
def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]:
def get_tags(self, raw_post: RawPost) -> Collection[Tag] | None:
"Return Tag list of given RawPost"
@classmethod
@ -201,9 +196,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
) -> list[tuple[PlatformTarget, list[Post]]]:
res: list[tuple[PlatformTarget, list[Post]]] = []
for user, cats, required_tags in users:
user_raw_post = await self.filter_user_custom(
new_posts, cats, required_tags
)
user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
user_post: list[Post] = []
for raw_post in user_raw_post:
user_post.append(await self.do_parse(raw_post))
@ -211,7 +204,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
return res
@abstractmethod
def get_category(self, post: RawPost) -> Optional[Category]:
def get_category(self, post: RawPost) -> Category | None:
"Return category of given Rawpost"
raise NotImplementedError()
@ -221,7 +214,7 @@ class MessageProcess(Platform, abstract=True):
def __init__(self, ctx: ProcessContext, client: AsyncClient):
super().__init__(ctx, client)
self.parse_cache: dict[Any, Post] = dict()
self.parse_cache: dict[Any, Post] = {}
@abstractmethod
def get_id(self, post: RawPost) -> Any:
@ -246,7 +239,7 @@ class MessageProcess(Platform, abstract=True):
"Get post list of the given target"
@abstractmethod
def get_date(self, post: RawPost) -> Optional[int]:
def get_date(self, post: RawPost) -> int | None:
"Get post timestamp and return, return None if can't get the time"
async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
@ -286,9 +279,7 @@ class NewMessage(MessageProcess, abstract=True):
inited: bool
exists_posts: set[Any]
async def filter_common_with_diff(
self, target: Target, raw_post_list: list[RawPost]
) -> list[RawPost]:
async def filter_common_with_diff(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]:
filtered_post = await self.filter_common(raw_post_list)
store = self.get_stored_data(target) or self.MessageStorage(False, set())
res = []
@ -297,11 +288,7 @@ class NewMessage(MessageProcess, abstract=True):
for raw_post in filtered_post:
post_id = self.get_id(raw_post)
store.exists_posts.add(post_id)
logger.info(
"init {}-{} with {}".format(
self.platform_name, target, store.exists_posts
)
)
logger.info(f"init {self.platform_name}-{target} with {store.exists_posts}")
store.inited = True
else:
for raw_post in filtered_post:
@ -400,12 +387,11 @@ class SimplePost(MessageProcess, abstract=True):
return res
def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
if typing.TYPE_CHECKING:
class NoTargetGroup(Platform, abstract=True):
platform_list: list[Type[Platform]]
platform_list: list[type[Platform]]
platform_obj_list: list[Platform]
DUMMY_STR = "_DUMMY"
@ -418,24 +404,18 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
for platform in platform_list:
if platform.has_target:
raise RuntimeError(
"Platform {} should have no target".format(platform.name)
)
raise RuntimeError(f"Platform {platform.name} should have no target")
if name == DUMMY_STR:
name = platform.name
elif name != platform.name:
raise RuntimeError("Platform name for {} not fit".format(platform_name))
raise RuntimeError(f"Platform name for {platform_name} not fit")
platform_category_key_set = set(platform.categories.keys())
if platform_category_key_set & categories_keys:
raise RuntimeError(
"Platform categories for {} duplicate".format(platform_name)
)
raise RuntimeError(f"Platform categories for {platform_name} duplicate")
categories_keys |= platform_category_key_set
categories.update(platform.categories)
if platform.scheduler != scheduler:
raise RuntimeError(
"Platform scheduler for {} not fit".format(platform_name)
)
raise RuntimeError(f"Platform scheduler for {platform_name} not fit")
def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient):
Platform.__init__(self, ctx, client)
@ -444,15 +424,13 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
self.platform_obj_list.append(platform_class(ctx, client))
def __str__(self: "NoTargetGroup") -> str:
return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]"
return "[" + " ".join(x.name for x in self.platform_list) + "]"
@classmethod
async def get_target_name(cls, client: AsyncClient, target: Target):
return await platform_list[0].get_target_name(client, target)
async def fetch_new_post(
self: "NoTargetGroup", target: Target, users: list[UserSubInfo]
):
async def fetch_new_post(self: "NoTargetGroup", target: Target, users: list[UserSubInfo]):
res = defaultdict(list)
for platform in self.platform_obj_list:
platform_res = await platform.fetch_new_post(target=target, users=users)
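Throughout this file, `"...".format(x)` calls are folded into f-strings (pyupgrade's f-string rule), `Optional[...]`/`Type[...]` become `... | None`/`type[...]`, and `dict()` becomes `{}`. A one-line sketch showing that the f-string rewrite is a pure equivalence:

```python
platform_name = "weibo"
legacy = "Platform {} should have no target".format(platform_name)
modern = f"Platform {platform_name} should have no target"
assert legacy == modern
```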

View File

@ -1,10 +1,10 @@
import calendar
import time
from typing import Any, Optional
from typing import Any
import feedparser
from bs4 import BeautifulSoup as bs
from httpx import AsyncClient
from bs4 import BeautifulSoup as bs
from ..post import Post
from ..types import RawPost, Target
@ -20,7 +20,6 @@ class RssSchedConf(SchedulerConfig):
class Rss(NewMessage):
categories = {}
enable_tag = False
platform_name = "rss"
@ -31,9 +30,7 @@ class Rss(NewMessage):
has_target = True
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
res = await client.get(target, timeout=10.0)
feed = feedparser.parse(res.text)
return feed["feed"]["title"]
@ -69,7 +66,7 @@ class Rss(NewMessage):
else:
text = f"{title}\n\n{desc}"
pics = list(map(lambda x: x.attrs["src"], soup("img")))
pics = [x.attrs["src"] for x in soup("img")]
if raw_post.get("media_content"):
for media in raw_post["media_content"]:
if media.get("medium") == "image" and media.get("url"):

View File

@ -1,17 +1,16 @@
import json
import re
from collections.abc import Callable
import json
from typing import Any
from datetime import datetime
from typing import Any, Optional
from bs4 import BeautifulSoup as bs
from httpx import AsyncClient
from nonebot.log import logger
from bs4 import BeautifulSoup as bs
from ..post import Post
from ..types import *
from ..utils import SchedulerConfig, http_client
from .platform import NewMessage
from ..utils import SchedulerConfig, http_client
from ..types import Tag, Target, RawPost, ApiError, Category
class WeiboSchedConf(SchedulerConfig):
@ -21,7 +20,6 @@ class WeiboSchedConf(SchedulerConfig):
class Weibo(NewMessage):
categories = {
1: "转发",
2: "视频",
@ -38,13 +36,9 @@ class Weibo(NewMessage):
parse_target_promot = "请输入用户主页包含数字UID的链接"
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
param = {"containerid": "100505" + target}
res = await client.get(
"https://m.weibo.cn/api/container/getIndex", params=param
)
res = await client.get("https://m.weibo.cn/api/container/getIndex", params=param)
res_dict = json.loads(res.text)
if res_dict.get("ok") == 1:
return res_dict["data"]["userInfo"]["screen_name"]
@ -63,13 +57,14 @@ class Weibo(NewMessage):
async def get_sub_list(self, target: Target) -> list[RawPost]:
params = {"containerid": "107603" + target}
res = await self.client.get(
"https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0
)
res = await self.client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0)
res_data = json.loads(res.text)
if not res_data["ok"] and res_data["msg"] != "这里还没有内容":
raise ApiError(res.request.url)
custom_filter: Callable[[RawPost], bool] = lambda d: d["card_type"] == 9
def custom_filter(d: RawPost) -> bool:
return d["card_type"] == 9
return list(filter(custom_filter, res_data["data"]["cards"]))
def get_id(self, post: RawPost) -> Any:
@ -79,44 +74,32 @@ class Weibo(NewMessage):
return raw_post["card_type"] == 9
def get_date(self, raw_post: RawPost) -> float:
created_time = datetime.strptime(
raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y"
)
created_time = datetime.strptime(raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y")
return created_time.timestamp()
def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]:
def get_tags(self, raw_post: RawPost) -> list[Tag] | None:
"Return Tag list of given RawPost"
text = raw_post["mblog"]["text"]
soup = bs(text, "html.parser")
res = list(
map(
lambda x: x[1:-1],
filter(
lambda s: s[0] == "#" and s[-1] == "#",
map(lambda x: x.text, soup.find_all("span", class_="surl-text")),
),
res = [
x[1:-1]
for x in filter(
lambda s: s[0] == "#" and s[-1] == "#",
(x.text for x in soup.find_all("span", class_="surl-text")),
)
)
super_topic_img = soup.find(
"img", src=re.compile(r"timeline_card_small_super_default")
)
]
super_topic_img = soup.find("img", src=re.compile(r"timeline_card_small_super_default"))
if super_topic_img:
try:
res.append(
super_topic_img.parent.parent.find("span", class_="surl-text").text # type: ignore
+ "超话"
)
except:
logger.info("super_topic extract error: {}".format(text))
res.append(super_topic_img.parent.parent.find("span", class_="surl-text").text + "超话") # type: ignore
except Exception:
logger.info(f"super_topic extract error: {text}")
return res
def get_category(self, raw_post: RawPost) -> Category:
if raw_post["mblog"].get("retweeted_status"):
return Category(1)
elif (
raw_post["mblog"].get("page_info")
and raw_post["mblog"]["page_info"].get("type") == "video"
):
elif raw_post["mblog"].get("page_info") and raw_post["mblog"]["page_info"].get("type") == "video":
return Category(2)
elif raw_post["mblog"].get("pics"):
return Category(3)
@ -129,7 +112,8 @@ class Weibo(NewMessage):
async def parse(self, raw_post: RawPost) -> Post:
header = {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,"
"*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
"accept-language": "zh-CN,zh;q=0.9",
"authority": "m.weibo.cn",
"cache-control": "max-age=0",
@ -147,26 +131,16 @@ class Weibo(NewMessage):
retweeted = True
pic_num = info["retweeted_status"]["pic_num"] if retweeted else info["pic_num"]
if info["isLongText"] or pic_num > 9:
res = await self.client.get(
"https://m.weibo.cn/detail/{}".format(info["mid"]), headers=header
)
res = await self.client.get(f"https://m.weibo.cn/detail/{info['mid']}", headers=header)
try:
match = re.search(r'"status": ([\s\S]+),\s+"call"', res.text)
assert match
full_json_text = match.group(1)
info = json.loads(full_json_text)
except:
logger.info(
"detail message error: https://m.weibo.cn/detail/{}".format(
info["mid"]
)
)
except Exception:
logger.info(f"detail message error: https://m.weibo.cn/detail/{info['mid']}")
parsed_text = self._get_text(info["text"])
raw_pics_list = (
info["retweeted_status"].get("pics", [])
if retweeted
else info.get("pics", [])
)
raw_pics_list = info["retweeted_status"].get("pics", []) if retweeted else info.get("pics", [])
pic_urls = [img["large"]["url"] for img in raw_pics_list]
pics = []
for pic_url in pic_urls:
@ -174,7 +148,7 @@ class Weibo(NewMessage):
res = await client.get(pic_url)
res.raise_for_status()
pics.append(res.content)
detail_url = "https://weibo.com/{}/{}".format(info["user"]["id"], info["bid"])
detail_url = f"https://weibo.com/{info['user']['id']}/{info['bid']}"
# return parsed_text, detail_url, pic_urls
return Post(
"weibo",

View File

@ -1,11 +1,8 @@
from typing import Optional
import nonebot
from pydantic import BaseSettings
class PlugConfig(BaseSettings):
bison_config_path: str = ""
bison_use_pic: bool = False
bison_init_filter: bool = True
@ -17,8 +14,12 @@ class PlugConfig(BaseSettings):
bison_use_pic_merge: int = 0 # 多图片时启用图片合并转发(仅限群)
# 0不启用1首条消息单独发送剩余照片合并转发2以及以上所有消息全部合并转发
bison_resend_times: int = 0
bison_proxy: Optional[str]
bison_ua: str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36"
bison_proxy: str | None
bison_ua: str = (
"Mozilla/5.0 (X11; Linux x86_64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/51.0.2704.103 Safari/537.36"
)
bison_show_network_warning: bool = True
class Config:
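A hedged sketch of what the documented `bison_use_pic_merge` levels imply for message splitting; `split_for_merge` is an invented helper name, not part of the plugin:

# Hypothetical illustration of the three documented levels:
#   0 -> no merged forwarding, 1 -> first message alone + the rest merged,
#   2 or more -> everything goes into one merged forward.
def split_for_merge(msgs: list[str], use_pic_merge: int) -> tuple[list[str], list[str]]:
    """Return (messages sent one by one, messages bundled into a merged forward)."""
    if use_pic_merge <= 0 or len(msgs) <= 1:
        return msgs, []
    if use_pic_merge == 1:
        return msgs[:1], msgs[1:]
    return [], msgs


print(split_for_merge(["text", "pic1", "pic2"], 1))  # (['text'], ['pic1', 'pic2'])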

View File

@ -1,3 +1,3 @@
from .post import Post
__all__ = ["Post", "CustomPost"]
__all__ = ["Post"]

View File

@ -1,7 +1,6 @@
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import reduce
from typing import Optional
from abc import abstractmethod
from dataclasses import field, dataclass
from nonebot_plugin_saa import MessageFactory, MessageSegmentFactory
@ -25,12 +24,12 @@ class BasePost:
class OptionalMixin:
# Because of https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
override_use_pic: Optional[bool] = None
override_use_pic: bool | None = None
compress: bool = False
extra_msg: list[MessageFactory] = field(default_factory=list)
def _use_pic(self):
if not self.override_use_pic is None:
if self.override_use_pic is not None:
return self.override_use_pic
return plugin_config.bison_use_pic
@ -44,13 +43,9 @@ class AbstractPost(OptionalMixin, BasePost):
msg_segments = await self.generate_text_messages()
if msg_segments:
if self.compress:
msgs = [
reduce(lambda x, y: x.append(y), msg_segments, MessageFactory([]))
]
msgs = [reduce(lambda x, y: x.append(y), msg_segments, MessageFactory([]))]
else:
msgs = list(
map(lambda msg_segment: MessageFactory([msg_segment]), msg_segments)
)
msgs = [MessageFactory([msg_segment]) for msg_segment in msg_segments]
else:
msgs = []
msgs.extend(self.extra_msg)

View File

@ -1,19 +1,17 @@
from dataclasses import dataclass, field
from typing import Optional
from dataclasses import field, dataclass
from nonebot.adapters.onebot.v11 import MessageSegment
from nonebot.log import logger
from nonebot.plugin import require
from nonebot_plugin_saa import Image, MessageFactory, MessageSegmentFactory, Text
from nonebot.adapters.onebot.v11 import MessageSegment
from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
from .abstract_post import AbstractPost, BasePost
from .abstract_post import BasePost, AbstractPost
@dataclass
class _CustomPost(BasePost):
ms_factories: list[MessageSegmentFactory] = field(default_factory=list)
css_path: Optional[str] = None # 模板文件所用css路径
css_path: str | None = None # 模板文件所用css路径
async def generate_text_messages(self) -> list[MessageSegmentFactory]:
return self.ms_factories
@ -31,15 +29,13 @@ class _CustomPost(BasePost):
for message_segment in self.ms_factories:
match message_segment:
case Text(data={"text": text}):
md += "{}<br>".format(text)
md += f"{text}<br>"
case Image(data={"image": image}):
# use onebot v11 to convert image into url
ob11_image = MessageSegment.image(image)
md += "![Image]({})\n".format(ob11_image.data["file"])
case _:
logger.warning(
"custom_post不支持处理类型:{}".format(type(message_segment))
)
logger.warning(f"custom_post不支持处理类型:{type(message_segment)}")
continue
return md
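The Text/Image match above reduces to a small conversion loop; a standalone sketch using plain dicts instead of real saa factories so it runs without a bot context (the dict layout is invented):

# Same match-based markdown conversion, on illustrative dict segments.
def to_markdown(segments: list[dict]) -> str:
    md = ""
    for seg in segments:
        match seg:
            case {"type": "text", "text": text}:
                md += f"{text}<br>"
            case {"type": "image", "url": url}:
                md += f"![Image]({url})\n"
            case _:
                continue  # unsupported segment types are skipped, as in _CustomPost
    return md


print(to_markdown([{"type": "text", "text": "hello"}, {"type": "image", "url": "https://example.com/a.png"}]))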

View File

@ -1,30 +1,27 @@
from dataclasses import dataclass, field
from functools import reduce
from io import BytesIO
from typing import Optional, Union
from dataclasses import field, dataclass
import nonebot_plugin_saa as saa
from nonebot.log import logger
from nonebot_plugin_saa.utils import MessageFactory, MessageSegmentFactory
from PIL import Image
from nonebot.log import logger
import nonebot_plugin_saa as saa
from nonebot_plugin_saa.utils import MessageSegmentFactory
from ..utils import http_client, parse_text
from .abstract_post import AbstractPost, BasePost, OptionalMixin
from ..utils import parse_text, http_client
from .abstract_post import BasePost, AbstractPost
@dataclass
class _Post(BasePost):
target_type: str
text: str
url: Optional[str] = None
target_name: Optional[str] = None
pics: list[Union[str, bytes]] = field(default_factory=list)
url: str | None = None
target_name: str | None = None
pics: list[str | bytes] = field(default_factory=list)
_message: Optional[list[MessageSegmentFactory]] = None
_pic_message: Optional[list[MessageSegmentFactory]] = None
_message: list[MessageSegmentFactory] | None = None
_pic_message: list[MessageSegmentFactory] | None = None
async def _pic_url_to_image(self, data: Union[str, bytes]) -> Image.Image:
async def _pic_url_to_image(self, data: str | bytes) -> Image.Image:
pic_buffer = BytesIO()
if isinstance(data, str):
async with http_client() as client:
@ -101,22 +98,19 @@ class _Post(BasePost):
self.pics.insert(0, target_io.getvalue())
async def generate_text_messages(self) -> list[MessageSegmentFactory]:
if self._message is None:
await self._pic_merge()
msg_segments: list[MessageSegmentFactory] = []
text = ""
if self.text:
text += "{}".format(
self.text if len(self.text) < 500 else self.text[:500] + "..."
)
text += "{}".format(self.text if len(self.text) < 500 else self.text[:500] + "...")
if text:
text += "\n"
text += "来源: {}".format(self.target_type)
text += f"来源: {self.target_type}"
if self.target_name:
text += " {}".format(self.target_name)
text += f" {self.target_name}"
if self.url:
text += " \n详情: {}".format(self.url)
text += f" \n详情: {self.url}"
msg_segments.append(saa.Text(text))
for pic in self.pics:
msg_segments.append(saa.Image(pic))
@ -124,17 +118,16 @@ class _Post(BasePost):
return self._message
async def generate_pic_messages(self) -> list[MessageSegmentFactory]:
if self._pic_message is None:
await self._pic_merge()
msg_segments: list[MessageSegmentFactory] = []
text = ""
if self.text:
text += "{}".format(self.text)
text += f"{self.text}"
text += "\n"
text += "来源: {}".format(self.target_type)
text += f"来源: {self.target_type}"
if self.target_name:
text += " {}".format(self.target_name)
text += f" {self.target_name}"
msg_segments.append(await parse_text(text))
if not self.target_type == "rss" and self.url:
msg_segments.append(saa.Text(self.url))
@ -149,14 +142,7 @@ class _Post(BasePost):
self.target_name,
self.text if len(self.text) < 500 else self.text[:500] + "...",
self.url,
", ".join(
map(
lambda x: "b64img"
if isinstance(x, bytes) or x.startswith("base64")
else x,
self.pics,
)
),
", ".join("b64img" if isinstance(x, bytes) or x.startswith("base64") else x for x in self.pics),
)
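The text layout built by generate_text_messages above, reduced to a pure function for reference; the field values in the call are invented:

# Sketch: 500-character truncation plus the 来源/详情 trailer, outside the dataclass.
def build_post_text(text: str, target_type: str, target_name: str | None, url: str | None) -> str:
    out = text if len(text) < 500 else text[:500] + "..."
    if out:
        out += "\n"
    out += f"来源: {target_type}"
    if target_name:
        out += f" {target_name}"
    if url:
        out += f" \n详情: {url}"
    return out


print(build_post_text("新微博内容", "weibo", "某账号", "https://weibo.com/xxx/yyy"))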

View File

@ -1 +1,3 @@
from .manager import *
from .manager import init_scheduler, scheduler_dict, handle_delete_target, handle_insert_new_target
__all__ = ["init_scheduler", "handle_delete_target", "handle_insert_new_target", "scheduler_dict"]

View File

@ -1,18 +1,16 @@
from typing import Type
from ..config import config
from ..config.db_model import Target
from ..platform import platform_manager
from ..types import Target as T_Target
from ..utils import SchedulerConfig
from .scheduler import Scheduler
from ..utils import SchedulerConfig
from ..config.db_model import Target
from ..types import Target as T_Target
from ..platform import platform_manager
scheduler_dict: dict[Type[SchedulerConfig], Scheduler] = {}
scheduler_dict: dict[type[SchedulerConfig], Scheduler] = {}
async def init_scheduler():
_schedule_class_dict: dict[Type[SchedulerConfig], list[Target]] = {}
_schedule_class_platform_dict: dict[Type[SchedulerConfig], list[str]] = {}
_schedule_class_dict: dict[type[SchedulerConfig], list[Target]] = {}
_schedule_class_platform_dict: dict[type[SchedulerConfig], list[str]] = {}
for platform in platform_manager.values():
scheduler_config = platform.scheduler
if not hasattr(scheduler_config, "name") or not scheduler_config.name:
@ -33,9 +31,7 @@ async def init_scheduler():
for target in target_list:
schedulable_args.append((target.platform_name, T_Target(target.target)))
platform_name_list = _schedule_class_platform_dict[scheduler_config]
scheduler_dict[scheduler_config] = Scheduler(
scheduler_config, schedulable_args, platform_name_list
)
scheduler_dict[scheduler_config] = Scheduler(scheduler_config, schedulable_args, platform_name_list)
config.register_add_target_hook(handle_insert_new_target)
config.register_delete_target_hook(handle_delete_target)

View File

@ -1,14 +1,13 @@
from dataclasses import dataclass
from typing import Optional, Type
from nonebot.log import logger
from nonebot_plugin_apscheduler import scheduler
from nonebot_plugin_saa.utils.exceptions import NoBotFound
from ..config import config
from ..platform import platform_manager
from ..send import send_msgs
from ..types import Target
from ..config import config
from ..send import send_msgs
from ..platform import platform_manager
from ..utils import ProcessContext, SchedulerConfig
@ -20,12 +19,11 @@ class Schedulable:
class Scheduler:
schedulable_list: list[Schedulable]
def __init__(
self,
scheduler_config: Type[SchedulerConfig],
scheduler_config: type[SchedulerConfig],
schedulables: list[tuple[str, Target]],
platform_name_list: list[str],
):
@ -37,15 +35,12 @@ class Scheduler:
self.scheduler_config_obj = self.scheduler_config()
self.schedulable_list = []
for platform_name, target in schedulables:
self.schedulable_list.append(
Schedulable(
platform_name=platform_name, target=target, current_weight=0
)
)
self.schedulable_list.append(Schedulable(platform_name=platform_name, target=target, current_weight=0))
self.platform_name_list = platform_name_list
self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值
logger.info(
f"register scheduler for {self.name} with {self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}"
f"register scheduler for {self.name} with "
f"{self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}"
)
scheduler.add_job(
self.exec_fetch,
@ -53,7 +48,7 @@ class Scheduler:
**self.scheduler_config.schedule_setting,
)
async def get_next_schedulable(self) -> Optional[Schedulable]:
async def get_next_schedulable(self) -> Schedulable | None:
if not self.schedulable_list:
return None
cur_weight = await config.get_current_weight_val(self.platform_name_list)
@ -61,16 +56,9 @@ class Scheduler:
self.pre_weight_val = 0
cur_max_schedulable = None
for schedulable in self.schedulable_list:
schedulable.current_weight += cur_weight[
f"{schedulable.platform_name}-{schedulable.target}"
]
weight_sum += cur_weight[
f"{schedulable.platform_name}-{schedulable.target}"
]
if (
not cur_max_schedulable
or cur_max_schedulable.current_weight < schedulable.current_weight
):
schedulable.current_weight += cur_weight[f"{schedulable.platform_name}-{schedulable.target}"]
weight_sum += cur_weight[f"{schedulable.platform_name}-{schedulable.target}"]
if not cur_max_schedulable or cur_max_schedulable.current_weight < schedulable.current_weight:
cur_max_schedulable = schedulable
assert cur_max_schedulable
cur_max_schedulable.current_weight -= weight_sum
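The selection loop above is essentially smooth weighted round-robin; a self-contained sketch with hard-coded weights (in the plugin the weights come from config.get_current_weight_val):

# Smooth weighted round-robin, as in get_next_schedulable, with invented weights.
from dataclasses import dataclass


@dataclass
class Item:
    name: str
    weight: int  # configured weight; the plugin reads this from config
    current_weight: int = 0


def pick(items: list[Item]) -> Item:
    total = 0
    best = None
    for it in items:
        it.current_weight += it.weight
        total += it.weight
        if best is None or it.current_weight > best.current_weight:
            best = it
    assert best is not None
    best.current_weight -= total  # push the winner back so the others catch up
    return best


items = [Item("weibo-a", 3), Item("bilibili-b", 1)]
print([pick(items).name for _ in range(8)])  # picks weibo-a roughly three times as often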
@ -80,9 +68,7 @@ class Scheduler:
context = ProcessContext()
if not (schedulable := await self.get_next_schedulable()):
return
logger.trace(
f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}"
)
logger.trace(f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}")
send_userinfo_list = await config.get_platform_target_subscribers(
schedulable.platform_name, schedulable.target
)
@ -92,9 +78,7 @@ class Scheduler:
try:
platform_obj = platform_manager[schedulable.platform_name](context, client)
to_send = await platform_obj.do_fetch_new_post(
schedulable.target, send_userinfo_list
)
to_send = await platform_obj.do_fetch_new_post(schedulable.target, send_userinfo_list)
except Exception as err:
records = context.gen_req_records()
for record in records:
@ -107,7 +91,7 @@ class Scheduler:
for user, send_list in to_send:
for send_post in send_list:
logger.info("send to {}: {}".format(user, send_post))
logger.info(f"send to {user}: {send_post}")
try:
await send_msgs(
user,
@ -119,19 +103,14 @@ class Scheduler:
def insert_new_schedulable(self, platform_name: str, target: Target):
self.pre_weight_val += 1000
self.schedulable_list.append(Schedulable(platform_name, target, 1000))
logger.info(
f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})"
)
logger.info(f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})")
def delete_schedulable(self, platform_name, target: Target):
if not self.schedulable_list:
return
to_find_idx = None
for idx, schedulable in enumerate(self.schedulable_list):
if (
schedulable.platform_name == platform_name
and schedulable.target == target
):
if schedulable.platform_name == platform_name and schedulable.target == target:
to_find_idx = idx
break
if to_find_idx is not None:

View File

@ -1,21 +1,23 @@
import importlib
import json
import time
from functools import partial, wraps
import importlib
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Coroutine, TypeVar
from typing import Any, TypeVar
from functools import wraps, partial
from collections.abc import Callable, Coroutine
from nonebot.log import logger
from ..config.subs_io import subscribes_export, subscribes_import
from ..config.subs_io.nbesf_model import v1, v2
from ..scheduler.manager import init_scheduler
from ..config.subs_io.nbesf_model import v1, v2
from ..config.subs_io import subscribes_export, subscribes_import
try:
from typing_extensions import ParamSpec
import anyio
import click
from typing_extensions import ParamSpec
except ImportError as e: # pragma: no cover
raise ImportError("请使用 `pip install nonebot-bison[cli]` 安装所需依赖") from e
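The @run_async decorator applied to the click commands below is defined outside this hunk; a plausible sketch of such a wrapper built on anyio, where the body is an assumption rather than the plugin's actual code:

# Hypothetical run_async: bridge async click callbacks to sync click via anyio.
from collections.abc import Callable, Coroutine
from functools import partial, wraps
from typing import Any, TypeVar

import anyio
from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")


def run_async(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, R]:
    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # anyio.run expects a zero-argument async callable, so bind the args first
        return anyio.run(partial(func, *args, **kwargs))

    return wrapper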
@ -65,9 +67,7 @@ def path_init(ctx, param, value):
@cli.command(help="导出Nonebot Bison Exchangable Subcribes File", name="export")
@click.option(
"--path", "-p", default=None, callback=path_init, help="导出路径, 如果不指定,则默认为工作目录"
)
@click.option("--path", "-p", default=None, callback=path_init, help="导出路径, 如果不指定,则默认为工作目录")
@click.option(
"--format",
default="json",
@ -76,7 +76,6 @@ def path_init(ctx, param, value):
)
@run_async
async def subs_export(path: Path, format: str):
await init_scheduler()
export_file = path / f"bison_subscribes_export_{int(time.time())}.{format}"
@ -121,7 +120,6 @@ async def subs_export(path: Path, format: str):
)
@run_async
async def subs_import(path: str, format: str):
await init_scheduler()
import_file_path = Path(path)

View File

@ -1,17 +1,16 @@
import asyncio
from collections import deque
from typing import Deque
from nonebot.adapters.onebot.v11.exception import ActionFailed
from nonebot.log import logger
from nonebot_plugin_saa import AggregatedMessageFactory, MessageFactory, PlatformTarget
from nonebot.adapters.onebot.v11.exception import ActionFailed
from nonebot_plugin_saa.utils.auto_select_bot import refresh_bots
from nonebot_plugin_saa import MessageFactory, PlatformTarget, AggregatedMessageFactory
from .plugin_config import plugin_config
Sendable = MessageFactory | AggregatedMessageFactory
QUEUE: Deque[tuple[PlatformTarget, Sendable, int]] = deque()
QUEUE: deque[tuple[PlatformTarget, Sendable, int]] = deque()
MESSGE_SEND_INTERVAL = 1.5
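The consumer of QUEUE sits outside this hunk; a simplified sketch of an interval-throttled sender loop over these (target, message, retry) tuples, where the retry handling and sleep placement are assumptions:

# Hypothetical drain loop: send one queued message, requeue on ActionFailed while
# retries remain, and pause MESSGE_SEND_INTERVAL seconds between sends.
import asyncio


async def _drain_queue():
    while QUEUE:
        target, msg, retry_left = QUEUE.popleft()
        try:
            await msg.send_to(target)
        except ActionFailed:
            if retry_left > 0:
                QUEUE.appendleft((target, msg, retry_left - 1))
        await asyncio.sleep(MESSGE_SEND_INTERVAL)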

View File

@ -1,21 +1,20 @@
import contextlib
from typing import Type, cast
from nonebot.adapters import Message, MessageTemplate
from nonebot.typing import T_State
from nonebot.matcher import Matcher
from nonebot.params import Arg, ArgPlainText
from nonebot.typing import T_State
from nonebot_plugin_saa import PlatformTarget, SupportedAdapters, Text
from nonebot.adapters import Message, MessageTemplate
from nonebot_plugin_saa import Text, PlatformTarget, SupportedAdapters
from ..apis import check_sub_target
from ..config import config
from ..config.db_config import SubscribeDupException
from ..platform import Platform, platform_manager
from ..types import Target
from ..config import config
from ..apis import check_sub_target
from ..platform import Platform, platform_manager
from ..config.db_config import SubscribeDupException
from .utils import common_platform, ensure_user_info, gen_handle_cancel
def do_add_sub(add_sub: Type[Matcher]):
def do_add_sub(add_sub: type[Matcher]):
handle_cancel = gen_handle_cancel(add_sub, "已中止订阅")
add_sub.handle()(ensure_user_info(add_sub))
@ -25,12 +24,7 @@ def do_add_sub(add_sub: Type[Matcher]):
state["_prompt"] = (
"请输入想要订阅的平台,目前支持,请输入冒号左边的名称:\n"
+ "".join(
[
"{}{}\n".format(
platform_name, platform_manager[platform_name].name
)
for platform_name in common_platform
]
[f"{platform_name}: {platform_manager[platform_name].name}\n" for platform_name in common_platform]
)
+ "要查看全部平台请输入:“全部”\n中止订阅过程请输入:“取消”"
)
@ -39,10 +33,7 @@ def do_add_sub(add_sub: Type[Matcher]):
async def parse_platform(state: T_State, platform: str = ArgPlainText()) -> None:
if platform == "全部":
message = "全部平台\n" + "\n".join(
[
"{}{}".format(platform_name, platform.name)
for platform_name, platform in platform_manager.items()
]
[f"{platform_name}: {platform.name}" for platform_name, platform in platform_manager.items()]
)
await add_sub.reject(message)
elif platform == "取消":
@ -57,9 +48,7 @@ def do_add_sub(add_sub: Type[Matcher]):
cur_platform = platform_manager[state["platform"]]
if cur_platform.has_target:
state["_prompt"] = (
("1." + cur_platform.parse_target_promot + "\n2.")
if cur_platform.parse_target_promot
else ""
("1." + cur_platform.parse_target_promot + "\n2.") if cur_platform.parse_target_promot else ""
) + "请输入订阅用户的id\n查询id获取方法请回复:“查询”"
else:
matcher.set_arg("raw_id", None) # type: ignore
@ -81,9 +70,7 @@ def do_add_sub(add_sub: Type[Matcher]):
image = "https://s3.bmp.ovh/imgs/2022/03/ab3cc45d83bd3dd3.jpg"
msg.overwrite(
SupportedAdapters.onebot_v11,
MessageSegment.share(
url=url, title=title, content=content, image=image
),
MessageSegment.share(url=url, title=title, content=content, image=image),
)
await msg.reject()
platform = platform_manager[state["platform"]]
@ -99,14 +86,12 @@ def do_add_sub(add_sub: Type[Matcher]):
await add_sub.reject("id输入错误")
state["id"] = raw_id_text
state["name"] = name
except (Platform.ParseTargetException):
except Platform.ParseTargetException:
await add_sub.reject("不能从你的输入中提取出id请检查你输入的内容是否符合预期")
else:
await add_sub.send(
"即将订阅的用户为:{} {} {}\n如有错误请输入“取消”重新订阅".format(
state["platform"], state["name"], state["id"]
)
)
f"即将订阅的用户为:{state['platform']} {state['name']} {state['id']}\n如有错误请输入“取消”重新订阅"
) # noqa: E501
@add_sub.handle()
async def prepare_get_categories(matcher: Matcher, state: T_State):
@ -125,7 +110,7 @@ def do_add_sub(add_sub: Type[Matcher]):
if platform_manager[state["platform"]].categories:
for cat in raw_cats_text.split():
if cat not in platform_manager[state["platform"]].reverse_category:
await add_sub.reject("不支持 {}".format(cat))
await add_sub.reject(f"不支持 {cat}")
res.append(platform_manager[state["platform"]].reverse_category[cat])
state["cats"] = res
@ -135,14 +120,16 @@ def do_add_sub(add_sub: Type[Matcher]):
matcher.set_arg("raw_tags", None) # type: ignore
state["tags"] = []
return
state["_prompt"] = '请输入要订阅/屏蔽的标签(不含#号)\n多个标签请使用空格隔开\n订阅所有标签输入"全部标签"\n具体规则回复"详情"'
state["_prompt"] = "请输入要订阅/屏蔽的标签(不含#号)\n" "多个标签请使用空格隔开\n" '订阅所有标签输入"全部标签"\n' '具体规则回复"详情"' # noqa: E501
@add_sub.got("raw_tags", MessageTemplate("{_prompt}"), [handle_cancel])
async def parser_tags(state: T_State, raw_tags: Message = Arg()):
raw_tags_text = raw_tags.extract_plain_text()
if raw_tags_text == "详情":
await add_sub.reject(
"订阅标签直接输入标签内容\n屏蔽标签请在标签名称前添加~号\n详见https://nonebot-bison.netlify.app/usage/#%E5%B9%B3%E5%8F%B0%E8%AE%A2%E9%98%85%E6%A0%87%E7%AD%BE-tag"
"订阅标签直接输入标签内容\n"
"屏蔽标签请在标签名称前添加~号\n"
"详见https://nonebot-bison.netlify.app/usage/#%E5%B9%B3%E5%8F%B0%E8%AE%A2%E9%98%85%E6%A0%87%E7%AD%BE-tag"
)
if raw_tags_text in ["全部标签", "全部", "全标签"]:
state["tags"] = []
@ -150,9 +137,7 @@ def do_add_sub(add_sub: Type[Matcher]):
state["tags"] = raw_tags_text.split()
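For context on the "~" convention mentioned in the help text above (a "~" prefix marks a blocked tag, an empty list means all tags); the evaluation below is a hypothetical illustration, not the plugin's actual filter:

# Invented helper: plain entries are wanted tags, "~tag" entries are banned tags.
def tag_filter_hit(post_tags: list[str], sub_tags: list[str]) -> bool:
    if not sub_tags:
        return True  # "全部标签": no filtering
    banned = {t[1:] for t in sub_tags if t.startswith("~")}
    wanted = {t for t in sub_tags if not t.startswith("~")}
    if banned & set(post_tags):
        return False
    if wanted and not (wanted & set(post_tags)):
        return False
    return True


print(tag_filter_hit(["明日方舟", "抽卡"], ["明日方舟", "~剧透"]))  # True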
@add_sub.handle()
async def add_sub_process(
state: T_State, user: PlatformTarget = Arg("target_user_info")
):
async def add_sub_process(state: T_State, user: PlatformTarget = Arg("target_user_info")):
try:
await config.add_subscribe(
user=user,

View File

@ -1,26 +1,22 @@
from typing import Type
from nonebot.typing import T_State
from nonebot.matcher import Matcher
from nonebot.params import Arg, EventPlainText
from nonebot.typing import T_State
from nonebot_plugin_saa import MessageFactory, PlatformTarget
from ..config import config
from ..platform import platform_manager
from ..types import Category
from ..utils import parse_text
from ..platform import platform_manager
from .utils import ensure_user_info, gen_handle_cancel
def do_del_sub(del_sub: Type[Matcher]):
def do_del_sub(del_sub: type[Matcher]):
handle_cancel = gen_handle_cancel(del_sub, "删除中止")
del_sub.handle()(ensure_user_info(del_sub))
@del_sub.handle()
async def send_list(
state: T_State, user_info: PlatformTarget = Arg("target_user_info")
):
async def send_list(state: T_State, user_info: PlatformTarget = Arg("target_user_info")):
sub_list = await config.list_subscribe(user_info)
if not sub_list:
await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅")
@ -31,22 +27,10 @@ def do_del_sub(del_sub: Type[Matcher]):
"platform_name": sub.target.platform_name,
"target": sub.target.target,
}
res += "{} {} {} {}\n".format(
index,
sub.target.platform_name,
sub.target.target_name,
sub.target.target,
)
res += f"{index} {sub.target.platform_name} {sub.target.target_name} {sub.target.target}\n"
platform = platform_manager[sub.target.platform_name]
if platform.categories:
res += " [{}]".format(
", ".join(
map(
lambda x: platform.categories[Category(x)],
sub.categories,
)
)
)
res += " [{}]".format(", ".join(platform.categories[Category(x)] for x in sub.categories))
if platform.enable_tag:
res += " {}".format(", ".join(sub.tags))
res += "\n"
@ -62,7 +46,7 @@ def do_del_sub(del_sub: Type[Matcher]):
try:
index = int(index_str)
await config.del_subscribe(user_info, **state["sub_table"][index])
except Exception as e:
except Exception:
await del_sub.reject("删除错误")
else:
await del_sub.finish("删除成功")

View File

@ -1,17 +1,15 @@
from typing import Type
from nonebot.matcher import Matcher
from nonebot.params import Arg
from nonebot.matcher import Matcher
from nonebot_plugin_saa import MessageFactory, PlatformTarget
from ..config import config
from ..platform import platform_manager
from ..types import Category
from ..utils import parse_text
from .utils import ensure_user_info
from ..platform import platform_manager
def do_query_sub(query_sub: Type[Matcher]):
def do_query_sub(query_sub: type[Matcher]):
query_sub.handle()(ensure_user_info(query_sub))
@query_sub.handle()
@ -19,19 +17,10 @@ def do_query_sub(query_sub: Type[Matcher]):
sub_list = await config.list_subscribe(user_info)
res = "订阅的帐号为:\n"
for sub in sub_list:
res += "{} {} {}".format(
# sub["target_type"], sub["target_name"], sub["target"]
sub.target.platform_name,
sub.target.target_name,
sub.target.target,
)
res += f"{sub.target.platform_name} {sub.target.target_name} {sub.target.target}"
platform = platform_manager[sub.target.platform_name]
if platform.categories:
res += " [{}]".format(
", ".join(
map(lambda x: platform.categories[Category(x)], sub.categories)
)
)
res += " [{}]".format(", ".join(platform.categories[Category(x)] for x in sub.categories))
if platform.enable_tag:
res += " {}".format(", ".join(sub.tags))
res += "\n"

View File

@ -1,13 +1,13 @@
import contextlib
from typing import Annotated, Type
from typing import Annotated
from nonebot.adapters import Event
from nonebot.matcher import Matcher
from nonebot.params import Depends, EventPlainText, EventToMe
from nonebot.permission import SUPERUSER
from nonebot.rule import Rule
from nonebot.adapters import Event
from nonebot.typing import T_State
from nonebot.matcher import Matcher
from nonebot.permission import SUPERUSER
from nonebot_plugin_saa import extract_target
from nonebot.params import Depends, EventToMe, EventPlainText
from ..platform import platform_manager
from ..plugin_config import plugin_config
@ -31,7 +31,7 @@ common_platform = [
]
def gen_handle_cancel(matcher: Type[Matcher], message: str):
def gen_handle_cancel(matcher: type[Matcher], message: str):
async def _handle_cancel(text: Annotated[str, EventPlainText()]):
if text == "取消":
await matcher.finish(message)
@ -39,12 +39,10 @@ def gen_handle_cancel(matcher: Type[Matcher], message: str):
return Depends(_handle_cancel)
def ensure_user_info(matcher: Type[Matcher]):
def ensure_user_info(matcher: type[Matcher]):
async def _check_user_info(state: T_State):
if not state.get("target_user_info"):
await matcher.finish(
"No target_user_info set, this shouldn't happen, please issue"
)
await matcher.finish("No target_user_info set, this shouldn't happen, please issue")
return _check_user_info

View File

@ -1,17 +1,16 @@
import difflib
import re
import sys
from typing import Union
import nonebot
from bs4 import BeautifulSoup as bs
from nonebot.log import default_format, logger
from nonebot.plugin import require
from nonebot_plugin_saa import Image, MessageSegmentFactory, Text
from bs4 import BeautifulSoup as bs
from nonebot.log import logger, default_format
from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
from ..plugin_config import plugin_config
from .context import ProcessContext
from .http import http_client
from .context import ProcessContext
from ..plugin_config import plugin_config
from .scheduler_config import SchedulerConfig, scheduler
__all__ = [
@ -30,7 +29,7 @@ class Singleton(type):
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]
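A quick usage note for the Singleton metaclass above; Counter is just an example class:

# Classes that use Singleton as metaclass share a single instance per class.
class Counter(metaclass=Singleton):
    def __init__(self):
        self.value = 0


a = Counter()
b = Counter()
a.value += 1
print(a is b, b.value)  # True 1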
@ -63,7 +62,7 @@ def html_to_text(html: str, query_dict: dict = {}) -> str:
class Filter:
def __init__(self) -> None:
self.level: Union[int, str] = "DEBUG"
self.level: int | str = "DEBUG"
def __call__(self, record):
module_name: str = record["name"]
@ -71,9 +70,7 @@ class Filter:
if module:
module_name = getattr(module, "__module_name__", module_name)
record["name"] = module_name.split(".")[0]
levelno = (
logger.level(self.level).no if isinstance(self.level, str) else self.level
)
levelno = logger.level(self.level).no if isinstance(self.level, str) else self.level
nonebot_warning_level = logger.level("WARNING").no
return (
record["level"].no >= levelno
@ -94,11 +91,7 @@ if plugin_config.bison_filter_log:
)
config = nonebot.get_driver().config
logger.success("Muted info & success from nonebot")
default_filter.level = (
("DEBUG" if config.debug else "INFO")
if config.log_level is None
else config.log_level
)
default_filter.level = ("DEBUG" if config.debug else "INFO") if config.log_level is None else config.log_level
def text_similarity(str1, str2) -> float:

View File

@ -1,6 +1,6 @@
from base64 import b64encode
from httpx import AsyncClient, Response
from httpx import Response, AsyncClient
class ProcessContext:
@ -33,8 +33,13 @@ class ProcessContext:
res = []
for req in self.reqs:
if self._should_print_content(req):
log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} {req.text}"
log_content = (
f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} {req.text}"
)
else:
log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} b64encoded: {b64encode(req.content[:50]).decode()}"
log_content = (
f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} "
f"b64encoded: {b64encode(req.content[:50]).decode()}"
)
res.append(log_content)
return res

View File

@ -1,5 +1,3 @@
import functools
import httpx
from ..plugin_config import plugin_config
@ -10,7 +8,6 @@ http_args = {
http_headers = {"user-agent": plugin_config.bison_ua}
@functools.wraps(httpx.AsyncClient)
def http_client(*args, **kwargs):
if headers := kwargs.get("headers"):
new_headers = http_headers.copy()
@ -18,7 +15,4 @@ def http_client(*args, **kwargs):
kwargs["headers"] = new_headers
else:
kwargs["headers"] = http_headers
return httpx.AsyncClient(*args, **kwargs)
http_client = functools.partial(http_client, **http_args)
return httpx.AsyncClient(*args, **kwargs, **http_args)
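Typical use of the rewritten http_client, which now passes the shared http_args straight into AsyncClient instead of re-wrapping the function with functools.partial; the URL and extra header below are illustrative:

# Callers get the plugin-wide defaults (bison_ua header plus whatever http_args holds),
# with any explicitly passed headers merged on top.
import asyncio


async def fetch_example():
    async with http_client(headers={"referer": "https://example.com"}) as client:
        resp = await client.get("https://example.com/api")
        return resp.status_code


# asyncio.run(fetch_example())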

View File

@ -1,4 +1,4 @@
from typing import Literal, Type
from typing import Literal
from httpx import AsyncClient
@ -7,7 +7,6 @@ from .http import http_client
class SchedulerConfig:
schedule_type: Literal["date", "interval", "cron"]
schedule_setting: dict
name: str
@ -25,9 +24,7 @@ class SchedulerConfig:
return self.default_http_client
def scheduler(
schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict
) -> Type[SchedulerConfig]:
def scheduler(schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict) -> type[SchedulerConfig]:
return type(
"AnonymousScheduleConfig",
(SchedulerConfig,),

View File

@ -85,7 +85,7 @@ line-length = 120
target-version = "py310"
[tool.black]
line-length = 120
line-length = 118
target-version = ["py310", "py311"]
include = '\.pyi?$'
extend-exclude = '''