🎨 按ruff的检查调整程序代码

This commit is contained in:
Azide 2023-07-16 00:22:20 +08:00 committed by felinae98
parent f232ce4c3e
commit dba8f2a9cb
42 changed files with 414 additions and 757 deletions

View File

@ -6,25 +6,18 @@ require("nonebot_plugin_saa")
import nonebot_plugin_saa import nonebot_plugin_saa
from . import (
admin_page,
bootstrap,
config,
platform,
post,
scheduler,
send,
sub_manager,
types,
utils,
)
from .plugin_config import PlugConfig, plugin_config from .plugin_config import PlugConfig, plugin_config
from . import post, send, types, utils, config, platform, bootstrap, scheduler, admin_page, sub_manager
__help__version__ = "0.7.3" __help__version__ = "0.7.3"
nonebot_plugin_saa.enable_auto_select_bot() nonebot_plugin_saa.enable_auto_select_bot()
__help__plugin__name__ = "nonebot_bison" __help__plugin__name__ = "nonebot_bison"
__usage__ = f"本bot可以提供b站、微博等社交媒体的消息订阅详情请查看本bot文档或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,发送“查询订阅”或“删除订阅”管理订阅" __usage__ = (
"本bot可以提供b站、微博等社交媒体的消息订阅详情请查看本bot文档"
f"或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,"
f"发送“查询订阅”或“删除订阅”管理订阅"
)
__supported_adapters__ = nonebot_plugin_saa.__plugin_meta__.supported_adapters __supported_adapters__ = nonebot_plugin_saa.__plugin_meta__.supported_adapters
@ -41,6 +34,7 @@ __plugin_meta__ = PluginMetadata(
__all__ = [ __all__ = [
"admin_page", "admin_page",
"bootstrap",
"config", "config",
"sub_manager", "sub_manager",
"post", "post",

View File

@ -1,17 +1,16 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Union
from nonebot import get_driver, on_command
from nonebot.adapters.onebot.v11 import Bot
from nonebot.adapters.onebot.v11.event import PrivateMessageEvent
from nonebot.drivers.fastapi import Driver
from nonebot.log import logger from nonebot.log import logger
from nonebot.rule import to_me from nonebot.rule import to_me
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot import get_driver, on_command
from nonebot.drivers.fastapi import Driver
from nonebot.adapters.onebot.v11 import Bot
from nonebot.adapters.onebot.v11.event import PrivateMessageEvent
from ..plugin_config import plugin_config
from .api import router as api_router from .api import router as api_router
from ..plugin_config import plugin_config
from .token_manager import token_manager as tm from .token_manager import token_manager as tm
STATIC_PATH = (Path(__file__).parent / "dist").resolve() STATIC_PATH = (Path(__file__).parent / "dist").resolve()
@ -28,11 +27,9 @@ def init_fastapi():
class SinglePageApplication(StaticFiles): class SinglePageApplication(StaticFiles):
def __init__(self, directory: os.PathLike, index="index.html"): def __init__(self, directory: os.PathLike, index="index.html"):
self.index = index self.index = index
super().__init__( super().__init__(directory=directory, packages=None, html=True, check_dir=True)
directory=directory, packages=None, html=True, check_dir=True
)
def lookup_path(self, path: str) -> tuple[str, Union[os.stat_result, None]]: def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
full_path, stat_res = super().lookup_path(path) full_path, stat_res = super().lookup_path(path)
if stat_res is None: if stat_res is None:
return super().lookup_path(self.index) return super().lookup_path(self.index)
@ -45,9 +42,7 @@ def init_fastapi():
description="nonebot-bison webui and api", description="nonebot-bison webui and api",
) )
nonebot_app.include_router(api_router) nonebot_app.include_router(api_router)
nonebot_app.mount( nonebot_app.mount("/", SinglePageApplication(directory=static_path), name="bison-frontend")
"/", SinglePageApplication(directory=static_path), name="bison-frontend"
)
app = driver.server_app app = driver.server_app
app.mount("/bison", nonebot_app, "nonebot-bison") app.mount("/bison", nonebot_app, "nonebot-bison")
@ -63,10 +58,9 @@ def init_fastapi():
if host in ["0.0.0.0", "127.0.0.1"]: if host in ["0.0.0.0", "127.0.0.1"]:
host = "localhost" host = "localhost"
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"Nonebot Bison frontend will be running at: " f"Nonebot Bison frontend will be running at: " f"<b><u>http://{host}:{port}/bison</u></b>"
f"<b><u>http://{host}:{port}/bison</u></b>"
) )
logger.opt(colors=True).info(f"该页面不能被直接访问请私聊bot <b><u>后台管理</u></b> 以获取可访问地址") logger.opt(colors=True).info("该页面不能被直接访问请私聊bot <b><u>后台管理</u></b> 以获取可访问地址")
def register_get_token_handler(): def register_get_token_handler():
@ -93,6 +87,4 @@ if (STATIC_PATH / "index.html").exists():
else: else:
logger.warning("your driver is not fastapi, webui feature will be disabled") logger.warning("your driver is not fastapi, webui feature will be disabled")
else: else:
logger.warning( logger.warning("Frontend file not found, please compile it or use docker or pypi version")
"Frontend file not found, please compile it or use docker or pypi version"
)

View File

@ -1,35 +1,30 @@
import nonebot import nonebot
from fastapi import status from fastapi import status
from fastapi.exceptions import HTTPException
from fastapi.param_functions import Depends
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from fastapi.security.oauth2 import OAuth2PasswordBearer from fastapi.param_functions import Depends
from fastapi.exceptions import HTTPException
from nonebot_plugin_saa import TargetQQGroup from nonebot_plugin_saa import TargetQQGroup
from fastapi.security.oauth2 import OAuth2PasswordBearer
from nonebot_plugin_saa.utils.auto_select_bot import get_bot from nonebot_plugin_saa.utils.auto_select_bot import get_bot
from ..apis import check_sub_target
from ..config import (
NoSuchSubscribeException,
NoSuchTargetException,
NoSuchUserException,
config,
)
from ..config.db_config import SubscribeDupException
from ..platform import platform_manager
from ..types import Target as T_Target
from ..types import WeightConfig from ..types import WeightConfig
from ..utils.get_bot import get_groups from ..apis import check_sub_target
from .jwt import load_jwt, pack_jwt from .jwt import load_jwt, pack_jwt
from ..types import Target as T_Target
from ..utils.get_bot import get_groups
from ..platform import platform_manager
from .token_manager import token_manager from .token_manager import token_manager
from ..config.db_config import SubscribeDupException
from ..config import NoSuchUserException, NoSuchTargetException, NoSuchSubscribeException, config
from .types import ( from .types import (
AddSubscribeReq, TokenResp,
GlobalConf, GlobalConf,
PlatformConfig,
StatusResp, StatusResp,
SubscribeResp,
PlatformConfig,
AddSubscribeReq,
SubscribeConfig, SubscribeConfig,
SubscribeGroupDetail, SubscribeGroupDetail,
SubscribeResp,
TokenResp,
) )
router = APIRouter(prefix="/api", tags=["api"]) router = APIRouter(prefix="/api", tags=["api"])
@ -44,9 +39,7 @@ async def get_jwt_obj(token: str = Depends(oath_scheme)):
return obj return obj
async def check_group_permission( async def check_group_permission(groupNumber: int, token_obj: dict = Depends(get_jwt_obj)):
groupNumber: int, token_obj: dict = Depends(get_jwt_obj)
):
groups = token_obj["groups"] groups = token_obj["groups"]
for group in groups: for group in groups:
if int(groupNumber) == group["id"]: if int(groupNumber) == group["id"]:
@ -95,15 +88,13 @@ async def auth(token: str) -> TokenResp:
jwt_obj = { jwt_obj = {
"id": qq, "id": qq,
"type": "admin", "type": "admin",
"groups": list( "groups": [
map( {
lambda info: {
"id": info["group_id"], "id": info["group_id"],
"name": info["group_name"], "name": info["group_name"],
}, }
await get_groups(), for info in await get_groups()
) ],
),
} }
ret_obj = TokenResp( ret_obj = TokenResp(
type="admin", type="admin",
@ -134,18 +125,16 @@ async def get_subs_info(jwt_obj: dict = Depends(get_jwt_obj)) -> SubscribeResp:
for group in groups: for group in groups:
group_id = group["id"] group_id = group["id"]
raw_subs = await config.list_subscribe(TargetQQGroup(group_id=group_id)) raw_subs = await config.list_subscribe(TargetQQGroup(group_id=group_id))
subs = list( subs = [
map( SubscribeConfig(
lambda sub: SubscribeConfig(
platformName=sub.target.platform_name, platformName=sub.target.platform_name,
targetName=sub.target.target_name, targetName=sub.target.target_name,
cats=sub.categories, cats=sub.categories,
tags=sub.tags, tags=sub.tags,
target=sub.target.target, target=sub.target.target,
),
raw_subs,
)
) )
for sub in raw_subs
]
res[group_id] = SubscribeGroupDetail(name=group["name"], subscribes=subs) res[group_id] = SubscribeGroupDetail(name=group["name"], subscribes=subs)
return res return res
@ -174,9 +163,7 @@ async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp:
@router.delete("/subs", dependencies=[Depends(check_group_permission)]) @router.delete("/subs", dependencies=[Depends(check_group_permission)])
async def del_group_sub(groupNumber: int, platformName: str, target: str): async def del_group_sub(groupNumber: int, platformName: str, target: str):
try: try:
await config.del_subscribe( await config.del_subscribe(TargetQQGroup(group_id=groupNumber), target, platformName)
TargetQQGroup(group_id=groupNumber), target, platformName
)
except (NoSuchUserException, NoSuchSubscribeException): except (NoSuchUserException, NoSuchSubscribeException):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe") raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe")
return StatusResp(ok=True, msg="") return StatusResp(ok=True, msg="")
@ -204,13 +191,9 @@ async def get_weight_config():
@router.put("/weight", dependencies=[Depends(check_is_superuser)]) @router.put("/weight", dependencies=[Depends(check_is_superuser)])
async def update_weigth_config( async def update_weigth_config(platformName: str, target: str, weight_config: WeightConfig):
platformName: str, target: str, weight_config: WeightConfig
):
try: try:
await config.update_time_weight_config( await config.update_time_weight_config(T_Target(target), platformName, weight_config)
T_Target(target), platformName, weight_config
)
except NoSuchTargetException: except NoSuchTargetException:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such subscribe") raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such subscribe")
return StatusResp(ok=True, msg="") return StatusResp(ok=True, msg="")

View File

@ -1,7 +1,6 @@
import datetime
import random import random
import string import string
from typing import Optional import datetime
import jwt import jwt
@ -16,8 +15,8 @@ def pack_jwt(obj: dict) -> str:
) )
def load_jwt(token: str) -> Optional[dict]: def load_jwt(token: str) -> dict | None:
try: try:
return jwt.decode(token, _key, algorithms=["HS256"]) return jwt.decode(token, _key, algorithms=["HS256"])
except: except Exception:
return None return None

View File

@ -1,6 +1,5 @@
import random import random
import string import string
from typing import Optional
from expiringdict import ExpiringDict from expiringdict import ExpiringDict
@ -9,7 +8,7 @@ class TokenManager:
def __init__(self): def __init__(self):
self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60 * 10) self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60 * 10)
def get_user(self, token: str) -> Optional[tuple]: def get_user(self, token: str) -> tuple | None:
res = self.token_manager.get(token) res = self.token_manager.get(token)
assert res is None or isinstance(res, tuple) assert res is None or isinstance(res, tuple)
return res return res

View File

@ -1,2 +1,4 @@
from .db_config import config from .db_config import config as config
from .utils import NoSuchSubscribeException, NoSuchTargetException, NoSuchUserException from .utils import NoSuchUserException as NoSuchUserException
from .utils import NoSuchTargetException as NoSuchTargetException
from .utils import NoSuchSubscribeException as NoSuchSubscribeException

View File

@ -1,20 +1,19 @@
import json
import os import os
from collections import defaultdict import json
from datetime import datetime
from os import path from os import path
from pathlib import Path from pathlib import Path
from typing import DefaultDict, Literal, Mapping, TypedDict from datetime import datetime
from collections import defaultdict
from typing import Literal, TypedDict
import nonebot
from nonebot.log import logger from nonebot.log import logger
from tinydb import Query, TinyDB from tinydb import Query, TinyDB
from ..utils import Singleton
from ..types import User, Target
from ..platform import platform_manager from ..platform import platform_manager
from ..plugin_config import plugin_config from ..plugin_config import plugin_config
from ..types import Target, User from .utils import NoSuchUserException, NoSuchSubscribeException
from ..utils import Singleton
from .utils import NoSuchSubscribeException, NoSuchUserException
supported_target_type = platform_manager.keys() supported_target_type = platform_manager.keys()
@ -89,17 +88,16 @@ class Config(metaclass=Singleton):
self.target_user_cat_cache = {} self.target_user_cat_cache = {}
self.target_user_tag_cache = {} self.target_user_tag_cache = {}
self.target_list = {} self.target_list = {}
self.next_index: DefaultDict[str, int] = defaultdict(lambda: 0) self.next_index: defaultdict[str, int] = defaultdict(lambda: 0)
else: else:
self.available = False self.available = False
def add_subscribe( def add_subscribe(self, user, user_type, target, target_name, target_type, cats, tags):
self, user, user_type, target, target_name, target_type, cats, tags
):
user_query = Query() user_query = Query()
query = (user_query.user == user) & (user_query.user_type == user_type) query = (user_query.user == user) & (user_query.user_type == user_type)
if user_data := self.user_target.get(query): if user_data := self.user_target.get(query):
# update # update
assert not isinstance(user_data, list)
subs: list = user_data.get("subs", []) subs: list = user_data.get("subs", [])
subs.append( subs.append(
{ {
@ -132,9 +130,8 @@ class Config(metaclass=Singleton):
def list_subscribe(self, user, user_type) -> list[SubscribeContent]: def list_subscribe(self, user, user_type) -> list[SubscribeContent]:
query = Query() query = Query()
if user_sub := self.user_target.get( if user_sub := self.user_target.get((query.user == user) & (query.user_type == user_type)):
(query.user == user) & (query.user_type == user_type) assert not isinstance(user_sub, list)
):
return user_sub["subs"] return user_sub["subs"]
return [] return []
@ -146,6 +143,7 @@ class Config(metaclass=Singleton):
query = (user_query.user == user) & (user_query.user_type == user_type) query = (user_query.user == user) & (user_query.user_type == user_type)
if not (query_res := self.user_target.get(query)): if not (query_res := self.user_target.get(query)):
raise NoSuchUserException() raise NoSuchUserException()
assert not isinstance(query_res, list)
subs = query_res.get("subs", []) subs = query_res.get("subs", [])
for idx, sub in enumerate(subs): for idx, sub in enumerate(subs):
if sub.get("target") == target and sub.get("target_type") == target_type: if sub.get("target") == target and sub.get("target_type") == target_type:
@ -155,13 +153,12 @@ class Config(metaclass=Singleton):
return return
raise NoSuchSubscribeException() raise NoSuchSubscribeException()
def update_subscribe( def update_subscribe(self, user, user_type, target, target_name, target_type, cats, tags):
self, user, user_type, target, target_name, target_type, cats, tags
):
user_query = Query() user_query = Query()
query = (user_query.user == user) & (user_query.user_type == user_type) query = (user_query.user == user) & (user_query.user_type == user_type)
if user_data := self.user_target.get(query): if user_data := self.user_target.get(query):
# update # update
assert not isinstance(user_data, list)
subs: list = user_data.get("subs", []) subs: list = user_data.get("subs", [])
find_flag = False find_flag = False
for item in subs: for item in subs:
@ -182,19 +179,13 @@ class Config(metaclass=Singleton):
def update_send_cache(self): def update_send_cache(self):
res = {target_type: defaultdict(list) for target_type in supported_target_type} res = {target_type: defaultdict(list) for target_type in supported_target_type}
cat_res = { cat_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
target_type: defaultdict(lambda: defaultdict(list)) tag_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
for target_type in supported_target_type
}
tag_res = {
target_type: defaultdict(lambda: defaultdict(list))
for target_type in supported_target_type
}
# res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} # res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
to_del = [] to_del = []
for user in self.user_target.all(): for user in self.user_target.all():
for sub in user.get("subs", []): for sub in user.get("subs", []):
if not sub.get("target_type") in supported_target_type: if sub.get("target_type") not in supported_target_type:
to_del.append( to_del.append(
{ {
"user": user["user"], "user": user["user"],
@ -204,36 +195,28 @@ class Config(metaclass=Singleton):
} }
) )
continue continue
res[sub["target_type"]][sub["target"]].append( res[sub["target_type"]][sub["target"]].append(User(user["user"], user["user_type"]))
User(user["user"], user["user_type"]) cat_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[
) "cats"
cat_res[sub["target_type"]][sub["target"]][ ]
"{}-{}".format(user["user_type"], user["user"]) tag_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[
] = sub["cats"] "tags"
tag_res[sub["target_type"]][sub["target"]][ ]
"{}-{}".format(user["user_type"], user["user"])
] = sub["tags"]
self.target_user_cache = res self.target_user_cache = res
self.target_user_cat_cache = cat_res self.target_user_cat_cache = cat_res
self.target_user_tag_cache = tag_res self.target_user_tag_cache = tag_res
for target_type in self.target_user_cache: for target_type in self.target_user_cache:
self.target_list[target_type] = list( self.target_list[target_type] = list(self.target_user_cache[target_type].keys())
self.target_user_cache[target_type].keys()
)
logger.info(f"Deleting {to_del}") logger.info(f"Deleting {to_del}")
for d in to_del: for d in to_del:
self.del_subscribe(**d) self.del_subscribe(**d)
def get_sub_category(self, target_type, target, user_type, user): def get_sub_category(self, target_type, target, user_type, user):
return self.target_user_cat_cache[target_type][target][ return self.target_user_cat_cache[target_type][target][f"{user_type}-{user}"]
"{}-{}".format(user_type, user)
]
def get_sub_tags(self, target_type, target, user_type, user): def get_sub_tags(self, target_type, target, user_type, user):
return self.target_user_tag_cache[target_type][target][ return self.target_user_tag_cache[target_type][target][f"{user_type}-{user}"]
"{}-{}".format(user_type, user)
]
def get_next_target(self, target_type): def get_next_target(self, target_type):
# FIXME 插入或删除target后对队列的影响但是并不是大问题 # FIXME 插入或删除target后对队列的影响但是并不是大问题

View File

@ -1,19 +1,19 @@
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from datetime import datetime, time from datetime import time, datetime
from typing import Awaitable, Callable, Optional, Sequence from collections.abc import Callable, Sequence, Awaitable
from nonebot_plugin_datastore import create_session
from nonebot_plugin_saa import PlatformTarget
from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from sqlalchemy.exc import IntegrityError
from sqlalchemy import func, delete, select
from nonebot_plugin_saa import PlatformTarget
from nonebot_plugin_datastore import create_session
from ..types import Category, PlatformWeightConfigResp, Tag from ..types import Tag
from ..types import Target as T_Target from ..types import Target as T_Target
from ..types import TimeWeightConfig, UserSubInfo, WeightConfig
from .db_model import ScheduleTimeWeight, Subscribe, Target, User
from .utils import NoSuchTargetException from .utils import NoSuchTargetException
from .db_model import User, Target, Subscribe, ScheduleTimeWeight
from ..types import Category, UserSubInfo, WeightConfig, TimeWeightConfig, PlatformWeightConfigResp
def _get_time(): def _get_time():
@ -48,23 +48,17 @@ class DBConfig:
): ):
async with create_session() as session: async with create_session() as session:
db_user_stmt = select(User).where(User.user_target == user.dict()) db_user_stmt = select(User).where(User.user_target == user.dict())
db_user: Optional[User] = await session.scalar(db_user_stmt) db_user: User | None = await session.scalar(db_user_stmt)
if not db_user: if not db_user:
db_user = User(user_target=user.dict()) db_user = User(user_target=user.dict())
session.add(db_user) session.add(db_user)
db_target_stmt = ( db_target_stmt = (
select(Target) select(Target).where(Target.platform_name == platform_name).where(Target.target == target)
.where(Target.platform_name == platform_name)
.where(Target.target == target)
) )
db_target: Optional[Target] = await session.scalar(db_target_stmt) db_target: Target | None = await session.scalar(db_target_stmt)
if not db_target: if not db_target:
db_target = Target( db_target = Target(target=target, platform_name=platform_name, target_name=target_name)
target=target, platform_name=platform_name, target_name=target_name await asyncio.gather(*[hook(platform_name, target) for hook in self.add_target_hook])
)
await asyncio.gather(
*[hook(platform_name, target) for hook in self.add_target_hook]
)
else: else:
db_target.target_name = target_name db_target.target_name = target_name
subscribe = Subscribe( subscribe = Subscribe(
@ -96,44 +90,25 @@ class DBConfig:
"""获取数据库中带有user、target信息的subscribe数据""" """获取数据库中带有user、target信息的subscribe数据"""
async with create_session() as session: async with create_session() as session:
query_stmt = ( query_stmt = (
select(Subscribe) select(Subscribe).join(User).options(selectinload(Subscribe.target), selectinload(Subscribe.user))
.join(User)
.options(selectinload(Subscribe.target), selectinload(Subscribe.user))
) )
subs = (await session.scalars(query_stmt)).all() subs = (await session.scalars(query_stmt)).all()
return subs return subs
async def del_subscribe( async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str):
self, user: PlatformTarget, target: str, platform_name: str
):
async with create_session() as session: async with create_session() as session:
user_obj = await session.scalar( user_obj = await session.scalar(select(User).where(User.user_target == user.dict()))
select(User).where(User.user_target == user.dict())
)
target_obj = await session.scalar( target_obj = await session.scalar(
select(Target).where( select(Target).where(Target.platform_name == platform_name, Target.target == target)
Target.platform_name == platform_name, Target.target == target
)
)
await session.execute(
delete(Subscribe).where(
Subscribe.user == user_obj, Subscribe.target == target_obj
)
) )
await session.execute(delete(Subscribe).where(Subscribe.user == user_obj, Subscribe.target == target_obj))
target_count = await session.scalar( target_count = await session.scalar(
select(func.count()) select(func.count()).select_from(Subscribe).where(Subscribe.target == target_obj)
.select_from(Subscribe)
.where(Subscribe.target == target_obj)
) )
if target_count == 0: if target_count == 0:
# delete empty target # delete empty target
await asyncio.gather( await asyncio.gather(*[hook(platform_name, T_Target(target)) for hook in self.delete_target_hook])
*[
hook(platform_name, T_Target(target))
for hook in self.delete_target_hook
]
)
await session.commit() await session.commit()
async def update_subscribe( async def update_subscribe(
@ -165,29 +140,22 @@ class DBConfig:
async def get_platform_target(self, platform_name: str) -> Sequence[Target]: async def get_platform_target(self, platform_name: str) -> Sequence[Target]:
async with create_session() as sess: async with create_session() as sess:
subq = select(Subscribe.target_id).distinct().subquery() subq = select(Subscribe.target_id).distinct().subquery()
query = ( query = select(Target).join(subq).where(Target.platform_name == platform_name)
select(Target).join(subq).where(Target.platform_name == platform_name)
)
return (await sess.scalars(query)).all() return (await sess.scalars(query)).all()
async def get_time_weight_config( async def get_time_weight_config(self, target: T_Target, platform_name: str) -> WeightConfig:
self, target: T_Target, platform_name: str
) -> WeightConfig:
async with create_session() as sess: async with create_session() as sess:
time_weight_conf = ( time_weight_conf = (
await sess.scalars( await sess.scalars(
select(ScheduleTimeWeight) select(ScheduleTimeWeight)
.where( .where(Target.platform_name == platform_name, Target.target == target)
Target.platform_name == platform_name, Target.target == target
)
.join(Target) .join(Target)
) )
).all() ).all()
targetObj = await sess.scalar( targetObj = await sess.scalar(
select(Target).where( select(Target).where(Target.platform_name == platform_name, Target.target == target)
Target.platform_name == platform_name, Target.target == target
)
) )
assert targetObj
return WeightConfig( return WeightConfig(
default=targetObj.default_schedule_weight, default=targetObj.default_schedule_weight,
time_config=[ time_config=[
@ -200,22 +168,16 @@ class DBConfig:
], ],
) )
async def update_time_weight_config( async def update_time_weight_config(self, target: T_Target, platform_name: str, conf: WeightConfig):
self, target: T_Target, platform_name: str, conf: WeightConfig
):
async with create_session() as sess: async with create_session() as sess:
targetObj = await sess.scalar( targetObj = await sess.scalar(
select(Target).where( select(Target).where(Target.platform_name == platform_name, Target.target == target)
Target.platform_name == platform_name, Target.target == target
)
) )
if not targetObj: if not targetObj:
raise NoSuchTargetException() raise NoSuchTargetException()
target_id = targetObj.id target_id = targetObj.id
targetObj.default_schedule_weight = conf.default targetObj.default_schedule_weight = conf.default
delete_statement = delete(ScheduleTimeWeight).where( delete_statement = delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id)
ScheduleTimeWeight.target_id == target_id
)
await sess.execute(delete_statement) await sess.execute(delete_statement)
for time_conf in conf.time_config: for time_conf in conf.time_config:
new_conf = ScheduleTimeWeight( new_conf = ScheduleTimeWeight(
@ -243,18 +205,13 @@ class DBConfig:
key = f"{target.platform_name}-{target.target}" key = f"{target.platform_name}-{target.target}"
weight = target.default_schedule_weight weight = target.default_schedule_weight
for time_conf in target.time_weight: for time_conf in target.time_weight:
if ( if time_conf.start_time <= cur_time and time_conf.end_time > cur_time:
time_conf.start_time <= cur_time
and time_conf.end_time > cur_time
):
weight = time_conf.weight weight = time_conf.weight
break break
res[key] = weight res[key] = weight
return res return res
async def get_platform_target_subscribers( async def get_platform_target_subscribers(self, platform_name: str, target: T_Target) -> list[UserSubInfo]:
self, platform_name: str, target: T_Target
) -> list[UserSubInfo]:
async with create_session() as sess: async with create_session() as sess:
query = ( query = (
select(Subscribe) select(Subscribe)
@ -263,16 +220,14 @@ class DBConfig:
.options(selectinload(Subscribe.user)) .options(selectinload(Subscribe.user))
) )
subsribes = (await sess.scalars(query)).all() subsribes = (await sess.scalars(query)).all()
return list( return [
map( UserSubInfo(
lambda subscribe: UserSubInfo(
PlatformTarget.deserialize(subscribe.user.user_target), PlatformTarget.deserialize(subscribe.user.user_target),
subscribe.categories, subscribe.categories,
subscribe.tags, subscribe.tags,
),
subsribes,
)
) )
for subscribe in subsribes
]
async def get_all_weight_config( async def get_all_weight_config(
self, self,
@ -281,9 +236,7 @@ class DBConfig:
async with create_session() as sess: async with create_session() as sess:
query = select(Target) query = select(Target)
targets = (await sess.scalars(query)).all() targets = (await sess.scalars(query)).all()
query = select(ScheduleTimeWeight).options( query = select(ScheduleTimeWeight).options(selectinload(ScheduleTimeWeight.target))
selectinload(ScheduleTimeWeight.target)
)
time_weights = (await sess.scalars(query)).all() time_weights = (await sess.scalars(query)).all()
for target in targets: for target in targets:
@ -293,9 +246,7 @@ class DBConfig:
target=T_Target(target.target), target=T_Target(target.target),
target_name=target.target_name, target_name=target.target_name,
platform_name=platform_name, platform_name=platform_name,
weight=WeightConfig( weight=WeightConfig(default=target.default_schedule_weight, time_config=[]),
default=target.default_schedule_weight, time_config=[]
),
) )
for time_weight_config in time_weights: for time_weight_config in time_weights:

View File

@ -1,22 +1,17 @@
from nonebot.log import logger from nonebot.log import logger
from nonebot_plugin_datastore.db import get_engine from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio.session import AsyncSession
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from .db_model import User, Target, Subscribe
from .config_legacy import Config, ConfigContent, drop from .config_legacy import Config, ConfigContent, drop
from .db_model import Subscribe, Target, User
async def data_migrate(): async def data_migrate():
config = Config() config = Config()
if config.available: if config.available:
logger.warning("You are still using legacy db, migrating to sqlite") logger.warning("You are still using legacy db, migrating to sqlite")
all_subs: list[ConfigContent] = list( all_subs: list[ConfigContent] = [ConfigContent(**item) for item in config.get_all_subscribe().all()]
map(
lambda item: ConfigContent(**item),
config.get_all_subscribe().all(),
)
)
async with AsyncSession(get_engine()) as sess: async with AsyncSession(get_engine()) as sess:
user_to_create = [] user_to_create = []
subscribe_to_create = [] subscribe_to_create = []
@ -37,8 +32,7 @@ async def data_migrate():
if key in user_sub_set: if key in user_sub_set:
# a user subscribe a target twice # a user subscribe a target twice
logger.error( logger.error(
f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次," f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次,随机采用了一个订阅" # noqa: E501
"随机采用了一个订阅"
) )
continue continue
user_sub_set.add(key) user_sub_set.add(key)
@ -69,11 +63,7 @@ async def data_migrate():
tags=sub["tags"], tags=sub["tags"],
) )
subscribe_to_create.append(subscribe_obj) subscribe_to_create.append(subscribe_obj)
sess.add_all( sess.add_all(user_to_create + [x[0] for x in platform_target_map.values()] + subscribe_to_create)
user_to_create
+ list(map(lambda x: x[0], platform_target_map.values()))
+ subscribe_to_create
)
await sess.commit() await sess.commit()
drop() drop()
logger.info("migrate success") logger.info("migrate success")

View File

@ -5,7 +5,6 @@ Revises: 5f3370328e44
Create Date: 2023-01-15 19:04:54.987491 Create Date: 2023-01-15 19:04:54.987491
""" """
import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.

View File

@ -5,7 +5,6 @@ Revises: 0571870f5222
Create Date: 2022-03-26 19:46:50.910721 Create Date: 2022-03-26 19:46:50.910721
""" """
import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
@ -18,14 +17,10 @@ depends_on = None
def upgrade(): def upgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("subscribe", schema=None) as batch_op: with op.batch_alter_table("subscribe", schema=None) as batch_op:
batch_op.create_unique_constraint( batch_op.create_unique_constraint("unique-subscribe-constraint", ["target_id", "user_id"])
"unique-subscribe-constraint", ["target_id", "user_id"]
)
with op.batch_alter_table("target", schema=None) as batch_op: with op.batch_alter_table("target", schema=None) as batch_op:
batch_op.create_unique_constraint( batch_op.create_unique_constraint("unique-target-constraint", ["target", "platform_name"])
"unique-target-constraint", ["target", "platform_name"]
)
with op.batch_alter_table("user", schema=None) as batch_op: with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"]) batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"])

View File

@ -1,15 +1,14 @@
from abc import ABC from abc import ABC
from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo
from pydantic import BaseModel from pydantic import BaseModel
from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo
from ....types import Category, Tag from ....types import Tag, Category
class NBESFBase(BaseModel, ABC): class NBESFBase(BaseModel, ABC):
version: int # 表示nbesf格式版本有效版本从1开始 version: int # 表示nbesf格式版本有效版本从1开始
groups: list = list() groups: list = []
class Config: class Config:
orm_mode = True orm_mode = True

View File

@ -1,25 +1,26 @@
from typing import cast
from collections import defaultdict from collections import defaultdict
from typing import Callable, cast from collections.abc import Callable
from nonebot.log import logger
from nonebot_plugin_datastore.db import create_session
from nonebot_plugin_saa import PlatformTarget
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm.strategy_options import selectinload from nonebot.log import logger
from sqlalchemy.sql.selectable import Select from sqlalchemy.sql.selectable import Select
from nonebot_plugin_saa import PlatformTarget
from nonebot_plugin_datastore.db import create_session
from sqlalchemy.orm.strategy_options import selectinload
from ..db_model import Subscribe, User
from .nbesf_model import NBESFBase, v1, v2
from .utils import NBESFVerMatchErr from .utils import NBESFVerMatchErr
from ..db_model import User, Subscribe
from .nbesf_model import NBESFBase, v1, v2
async def subscribes_export(selector: Callable[[Select], Select]) -> v2.SubGroup: async def subscribes_export(selector: Callable[[Select], Select]) -> v2.SubGroup:
""" """
将Bison订阅导出为 Nonebot Bison Exchangable Subscribes File 标准格式的 SubGroup 类型数据 将Bison订阅导出为 Nonebot Bison Exchangable Subscribes File 标准格式的 SubGroup 类型数据
selector: selector:
sqlalchemy Select 对象的操作函数用于限定查询范围 e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group") sqlalchemy Select 对象的操作函数用于限定查询范围
e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group")
""" """
async with create_session() as sess: async with create_session() as sess:
sub_stmt = select(Subscribe).join(User) sub_stmt = select(Subscribe).join(User)

View File

@ -1,23 +1,22 @@
from collections import defaultdict
from importlib import import_module
from pathlib import Path from pathlib import Path
from pkgutil import iter_modules from pkgutil import iter_modules
from typing import DefaultDict, Type from collections import defaultdict
from importlib import import_module
from .platform import Platform, make_no_target_group from .platform import Platform, make_no_target_group
_package_dir = str(Path(__file__).resolve().parent) _package_dir = str(Path(__file__).resolve().parent)
for (_, module_name, _) in iter_modules([_package_dir]): for _, module_name, _ in iter_modules([_package_dir]):
import_module(f"{__name__}.{module_name}") import_module(f"{__name__}.{module_name}")
_platform_list: DefaultDict[str, list[Type[Platform]]] = defaultdict(list) _platform_list: defaultdict[str, list[type[Platform]]] = defaultdict(list)
for _platform in Platform.registry: for _platform in Platform.registry:
if not _platform.enabled: if not _platform.enabled:
continue continue
_platform_list[_platform.platform_name].append(_platform) _platform_list[_platform.platform_name].append(_platform)
platform_manager: dict[str, Type[Platform]] = dict() platform_manager: dict[str, type[Platform]] = {}
for name, platform_list in _platform_list.items(): for name, platform_list in _platform_list.items():
if len(platform_list) == 1: if len(platform_list) == 1:
platform_manager[name] = platform_list[0] platform_manager[name] = platform_list[0]

View File

@ -1,25 +1,23 @@
import json import json
from typing import Any, Optional from typing import Any
from bs4 import BeautifulSoup as bs
from httpx import AsyncClient from httpx import AsyncClient
from nonebot.plugin import require from nonebot.plugin import require
from bs4 import BeautifulSoup as bs
from ..post import Post from ..post import Post
from ..types import Category, RawPost, Target from ..types import Target, RawPost, Category
from ..utils.scheduler_config import SchedulerConfig from ..utils.scheduler_config import SchedulerConfig
from .platform import CategoryNotRecognize, NewMessage, StatusChange from .platform import NewMessage, StatusChange, CategoryNotRecognize
class ArknightsSchedConf(SchedulerConfig): class ArknightsSchedConf(SchedulerConfig):
name = "arknights" name = "arknights"
schedule_type = "interval" schedule_type = "interval"
schedule_setting = {"seconds": 30} schedule_setting = {"seconds": 30}
class Arknights(NewMessage): class Arknights(NewMessage):
categories = {1: "游戏公告"} categories = {1: "游戏公告"}
platform_name = "arknights" platform_name = "arknights"
name = "明日方舟游戏信息" name = "明日方舟游戏信息"
@ -30,9 +28,7 @@ class Arknights(NewMessage):
has_target = False has_target = False
@classmethod @classmethod
async def get_target_name( async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "明日方舟游戏信息" return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]: async def get_sub_list(self, _) -> list[RawPost]:
@ -92,7 +88,6 @@ class Arknights(NewMessage):
class AkVersion(StatusChange): class AkVersion(StatusChange):
categories = {2: "更新信息"} categories = {2: "更新信息"}
platform_name = "arknights" platform_name = "arknights"
name = "明日方舟游戏信息" name = "明日方舟游戏信息"
@ -103,15 +98,11 @@ class AkVersion(StatusChange):
has_target = False has_target = False
@classmethod @classmethod
async def get_target_name( async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "明日方舟游戏信息" return "明日方舟游戏信息"
async def get_status(self, _): async def get_status(self, _):
res_ver = await self.client.get( res_ver = await self.client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version")
"https://ak-conf.hypergryph.com/config/prod/official/IOS/version"
)
res_preanounce = await self.client.get( res_preanounce = await self.client.get(
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json" "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json"
) )
@ -121,20 +112,10 @@ class AkVersion(StatusChange):
def compare_status(self, _, old_status, new_status): def compare_status(self, _, old_status, new_status):
res = [] res = []
if ( if old_status.get("preAnnounceType") == 2 and new_status.get("preAnnounceType") == 0:
old_status.get("preAnnounceType") == 2 res.append(Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息")) # noqa: E501
and new_status.get("preAnnounceType") == 0 elif old_status.get("preAnnounceType") == 0 and new_status.get("preAnnounceType") == 2:
): res.append(Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息")) # noqa: E501
res.append(
Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息")
)
elif (
old_status.get("preAnnounceType") == 0
and new_status.get("preAnnounceType") == 2
):
res.append(
Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息")
)
if old_status.get("clientVersion") != new_status.get("clientVersion"): if old_status.get("clientVersion") != new_status.get("clientVersion"):
res.append(Post("arknights", text="游戏本体更新(大更新)", target_name="明日方舟更新信息")) res.append(Post("arknights", text="游戏本体更新(大更新)", target_name="明日方舟更新信息"))
if old_status.get("resVersion") != new_status.get("resVersion"): if old_status.get("resVersion") != new_status.get("resVersion"):
@ -149,7 +130,6 @@ class AkVersion(StatusChange):
class MonsterSiren(NewMessage): class MonsterSiren(NewMessage):
categories = {3: "塞壬唱片新闻"} categories = {3: "塞壬唱片新闻"}
platform_name = "arknights" platform_name = "arknights"
name = "明日方舟游戏信息" name = "明日方舟游戏信息"
@ -160,15 +140,11 @@ class MonsterSiren(NewMessage):
has_target = False has_target = False
@classmethod @classmethod
async def get_target_name( async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "明日方舟游戏信息" return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]: async def get_sub_list(self, _) -> list[RawPost]:
raw_data = await self.client.get( raw_data = await self.client.get("https://monster-siren.hypergryph.com/api/news")
"https://monster-siren.hypergryph.com/api/news"
)
return raw_data.json()["data"]["list"] return raw_data.json()["data"]["list"]
def get_id(self, post: RawPost) -> Any: def get_id(self, post: RawPost) -> Any:
@ -182,14 +158,12 @@ class MonsterSiren(NewMessage):
async def parse(self, raw_post: RawPost) -> Post: async def parse(self, raw_post: RawPost) -> Post:
url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}' url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}'
res = await self.client.get( res = await self.client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}')
f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}'
)
raw_data = res.json() raw_data = res.json()
content = raw_data["data"]["content"] content = raw_data["data"]["content"]
content = content.replace("</p>", "</p>\n") content = content.replace("</p>", "</p>\n")
soup = bs(content, "html.parser") soup = bs(content, "html.parser")
imgs = list(map(lambda x: x["src"], soup("img"))) imgs = [x["src"] for x in soup("img")]
text = f'{raw_post["title"]}\n{soup.text.strip()}' text = f'{raw_post["title"]}\n{soup.text.strip()}'
return Post( return Post(
"monster-siren", "monster-siren",
@ -203,7 +177,6 @@ class MonsterSiren(NewMessage):
class TerraHistoricusComic(NewMessage): class TerraHistoricusComic(NewMessage):
categories = {4: "泰拉记事社漫画"} categories = {4: "泰拉记事社漫画"}
platform_name = "arknights" platform_name = "arknights"
name = "明日方舟游戏信息" name = "明日方舟游戏信息"
@ -214,15 +187,11 @@ class TerraHistoricusComic(NewMessage):
has_target = False has_target = False
@classmethod @classmethod
async def get_target_name( async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "明日方舟游戏信息" return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]: async def get_sub_list(self, _) -> list[RawPost]:
raw_data = await self.client.get( raw_data = await self.client.get("https://terra-historicus.hypergryph.com/api/recentUpdate")
"https://terra-historicus.hypergryph.com/api/recentUpdate"
)
return raw_data.json()["data"] return raw_data.json()["data"]
def get_id(self, post: RawPost) -> Any: def get_id(self, post: RawPost) -> Any:

View File

@ -1,14 +1,14 @@
import json
import re import re
import json
from typing import Any
from copy import deepcopy from copy import deepcopy
from datetime import datetime, timedelta
from enum import Enum, unique from enum import Enum, unique
from typing import Any, Literal, Optional from typing_extensions import Self
from datetime import datetime, timedelta
from httpx import AsyncClient from httpx import AsyncClient
from nonebot.log import logger from nonebot.log import logger
from pydantic import BaseModel, Field from pydantic import Field, BaseModel
from typing_extensions import Self
from ..post import Post from ..post import Post
from ..types import ApiError, Category, RawPost, Tag, Target from ..types import ApiError, Category, RawPost, Tag, Target
@ -25,9 +25,7 @@ class BilibiliSchedConf(SchedulerConfig):
cookie_expire_time = timedelta(hours=5) cookie_expire_time = timedelta(hours=5)
def __init__(self): def __init__(self):
self._client_refresh_time = datetime( self._client_refresh_time = datetime(year=2000, month=1, day=1) # an expired time
year=2000, month=1, day=1
) # an expired time
super().__init__() super().__init__()
async def _init_session(self): async def _init_session(self):
@ -69,12 +67,8 @@ class Bilibili(NewMessage):
parse_target_promot = "请输入用户主页的链接" parse_target_promot = "请输入用户主页的链接"
@classmethod @classmethod
async def get_target_name( async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
cls, client: AsyncClient, target: Target res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target})
) -> Optional[str]:
res = await client.get(
"https://api.bilibili.com/x/web-interface/card", params={"mid": target}
)
res.raise_for_status() res.raise_for_status()
res_data = res.json() res_data = res.json()
if res_data["code"]: if res_data["code"]:
@ -129,12 +123,7 @@ class Bilibili(NewMessage):
return self._do_get_category(post_type) return self._do_get_category(post_type)
def get_tags(self, raw_post: RawPost) -> list[Tag]: def get_tags(self, raw_post: RawPost) -> list[Tag]:
return [ return [*(tp["topic_name"] for tp in raw_post["display"]["topic_info"]["topic_details"])]
*map(
lambda tp: tp["topic_name"],
raw_post["display"]["topic_info"]["topic_details"],
)
]
def _get_info(self, post_type: Category, card) -> tuple[str, list]: def _get_info(self, post_type: Category, card) -> tuple[str, list]:
if post_type == 1: if post_type == 1:
@ -178,24 +167,16 @@ class Bilibili(NewMessage):
url = "" url = ""
if post_type == 1: if post_type == 1:
# 一般动态 # 一般动态
url = "https://t.bilibili.com/{}".format( url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"])
raw_post["desc"]["dynamic_id_str"]
)
elif post_type == 2: elif post_type == 2:
# 专栏文章 # 专栏文章
url = "https://www.bilibili.com/read/cv{}".format( url = "https://www.bilibili.com/read/cv{}".format(raw_post["desc"]["rid"])
raw_post["desc"]["rid"]
)
elif post_type == 3: elif post_type == 3:
# 视频 # 视频
url = "https://www.bilibili.com/video/{}".format( url = "https://www.bilibili.com/video/{}".format(raw_post["desc"]["bvid"])
raw_post["desc"]["bvid"]
)
elif post_type == 4: elif post_type == 4:
# 纯文字 # 纯文字
url = "https://t.bilibili.com/{}".format( url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"])
raw_post["desc"]["dynamic_id_str"]
)
text, pic = self._get_info(post_type, card_content) text, pic = self._get_info(post_type, card_content)
elif post_type == 5: elif post_type == 5:
# 转发 # 转发
@ -261,10 +242,7 @@ class Bilibililive(StatusChange):
def get_live_action(self, old_info: Self) -> "Bilibililive.LiveAction": def get_live_action(self, old_info: Self) -> "Bilibililive.LiveAction":
status = Bilibililive.LiveStatus status = Bilibililive.LiveStatus
action = Bilibililive.LiveAction action = Bilibililive.LiveAction
if ( if old_info.live_status in [status.OFF, status.CYCLE] and self.live_status == status.ON:
old_info.live_status in [status.OFF, status.CYCLE]
and self.live_status == status.ON
):
return action.TURN_ON return action.TURN_ON
elif old_info.live_status == status.ON and self.live_status in [ elif old_info.live_status == status.ON and self.live_status in [
status.OFF, status.OFF,
@ -281,12 +259,8 @@ class Bilibililive(StatusChange):
return action.OFF return action.OFF
@classmethod @classmethod
async def get_target_name( async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
cls, client: AsyncClient, target: Target res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target})
) -> Optional[str]:
res = await client.get(
"https://api.bilibili.com/x/web-interface/card", params={"mid": target}
)
res_data = json.loads(res.text) res_data = json.loads(res.text)
if res_data["code"]: if res_data["code"]:
return None return None
@ -382,9 +356,7 @@ class BilibiliBangumi(StatusChange):
_url = "https://api.bilibili.com/pgc/review/user" _url = "https://api.bilibili.com/pgc/review/user"
@classmethod @classmethod
async def get_target_name( async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
cls, client: AsyncClient, target: Target
) -> Optional[str]:
res = await client.get(cls._url, params={"media_id": target}) res = await client.get(cls._url, params={"media_id": target})
res_data = res.json() res_data = res.json()
if res_data["code"]: if res_data["code"]:

View File

@ -1,15 +1,14 @@
from typing import Any, Optional from typing import Any
from httpx import AsyncClient from httpx import AsyncClient
from ..post import Post from ..post import Post
from ..types import RawPost, Target
from ..utils import scheduler from ..utils import scheduler
from .platform import NewMessage from .platform import NewMessage
from ..types import Target, RawPost
class FF14(NewMessage): class FF14(NewMessage):
categories = {} categories = {}
platform_name = "ff14" platform_name = "ff14"
name = "最终幻想XIV官方公告" name = "最终幻想XIV官方公告"
@ -21,9 +20,7 @@ class FF14(NewMessage):
has_target = False has_target = False
@classmethod @classmethod
async def get_target_name( async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "最终幻想XIV官方公告" return "最终幻想XIV官方公告"
async def get_sub_list(self, _) -> list[RawPost]: async def get_sub_list(self, _) -> list[RawPost]:

View File

@ -2,15 +2,15 @@ import re
import time import time
import traceback import traceback
from bs4 import BeautifulSoup, Tag
from httpx import AsyncClient from httpx import AsyncClient
from nonebot.log import logger from nonebot.log import logger
from bs4 import Tag, BeautifulSoup
from nonebot.plugin import require from nonebot.plugin import require
from ..post import Post from ..post import Post
from ..types import Category, RawPost, Target from ..types import Target, RawPost, Category
from ..utils import SchedulerConfig, http_client from ..utils import SchedulerConfig, http_client
from .platform import CategoryNotRecognize, CategoryNotSupport, NewMessage from .platform import NewMessage, CategoryNotSupport, CategoryNotRecognize
class McbbsnewsSchedConf(SchedulerConfig): class McbbsnewsSchedConf(SchedulerConfig):
@ -134,9 +134,9 @@ class McbbsNews(NewMessage):
if categoty_name in category_values: if categoty_name in category_values:
category_id = category_keys[category_values.index(categoty_name)] category_id = category_keys[category_values.index(categoty_name)]
elif categoty_name in known_category_values: elif categoty_name in known_category_values:
raise CategoryNotSupport("McbbsNews订阅暂不支持 {}".format(categoty_name)) raise CategoryNotSupport(f"McbbsNews订阅暂不支持 {categoty_name}")
else: else:
raise CategoryNotRecognize("Mcbbsnews订阅尚未识别 {}".format(categoty_name)) raise CategoryNotRecognize(f"Mcbbsnews订阅尚未识别 {categoty_name}")
return category_id return category_id
async def parse(self, post: RawPost) -> Post: async def parse(self, post: RawPost) -> Post:
@ -170,7 +170,7 @@ class McbbsNews(NewMessage):
一般而言每条新闻的长度都很可观图片生成时间比较喜人 一般而言每条新闻的长度都很可观图片生成时间比较喜人
""" """
require("nonebot_plugin_htmlrender") require("nonebot_plugin_htmlrender")
from nonebot_plugin_htmlrender import capture_element, text_to_pic from nonebot_plugin_htmlrender import text_to_pic, capture_element
try: try:
assert url assert url
@ -181,7 +181,7 @@ class McbbsNews(NewMessage):
device_scale_factor=3, device_scale_factor=3,
) )
assert pic_data assert pic_data
except: except Exception:
err_info = traceback.format_exc() err_info = traceback.format_exc()
logger.warning(f"渲染错误:{err_info}") logger.warning(f"渲染错误:{err_info}")

View File

@ -1,23 +1,21 @@
import re import re
from typing import Any, Optional from typing import Any
from httpx import AsyncClient from httpx import AsyncClient
from ..post import Post from ..post import Post
from ..types import ApiError, RawPost, Target
from ..utils import SchedulerConfig
from .platform import NewMessage from .platform import NewMessage
from ..utils import SchedulerConfig
from ..types import Target, RawPost, ApiError
class NcmSchedConf(SchedulerConfig): class NcmSchedConf(SchedulerConfig):
name = "music.163.com" name = "music.163.com"
schedule_type = "interval" schedule_type = "interval"
schedule_setting = {"minutes": 1} schedule_setting = {"minutes": 1}
class NcmArtist(NewMessage): class NcmArtist(NewMessage):
categories = {} categories = {}
platform_name = "ncm-artist" platform_name = "ncm-artist"
enable_tag = False enable_tag = False
@ -29,11 +27,9 @@ class NcmArtist(NewMessage):
parse_target_promot = "请输入歌手主页包含数字ID的链接" parse_target_promot = "请输入歌手主页包含数字ID的链接"
@classmethod @classmethod
async def get_target_name( async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
cls, client: AsyncClient, target: Target
) -> Optional[str]:
res = await client.get( res = await client.get(
"https://music.163.com/api/artist/albums/{}".format(target), f"https://music.163.com/api/artist/albums/{target}",
headers={"Referer": "https://music.163.com/"}, headers={"Referer": "https://music.163.com/"},
) )
res_data = res.json() res_data = res.json()
@ -45,16 +41,14 @@ class NcmArtist(NewMessage):
async def parse_target(cls, target_text: str) -> Target: async def parse_target(cls, target_text: str) -> Target:
if re.match(r"^\d+$", target_text): if re.match(r"^\d+$", target_text):
return Target(target_text) return Target(target_text)
elif match := re.match( elif match := re.match(r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text):
r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text
):
return Target(match.group(1)) return Target(match.group(1))
else: else:
raise cls.ParseTargetException() raise cls.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]: async def get_sub_list(self, target: Target) -> list[RawPost]:
res = await self.client.get( res = await self.client.get(
"https://music.163.com/api/artist/albums/{}".format(target), f"https://music.163.com/api/artist/albums/{target}",
headers={"Referer": "https://music.163.com/"}, headers={"Referer": "https://music.163.com/"},
) )
res_data = res.json() res_data = res.json()
@ -74,13 +68,10 @@ class NcmArtist(NewMessage):
target_name = raw_post["artist"]["name"] target_name = raw_post["artist"]["name"]
pics = [raw_post["picUrl"]] pics = [raw_post["picUrl"]]
url = "https://music.163.com/#/album?id={}".format(raw_post["id"]) url = "https://music.163.com/#/album?id={}".format(raw_post["id"])
return Post( return Post("ncm-artist", text=text, url=url, pics=pics, target_name=target_name)
"ncm-artist", text=text, url=url, pics=pics, target_name=target_name
)
class NcmRadio(NewMessage): class NcmRadio(NewMessage):
categories = {} categories = {}
platform_name = "ncm-radio" platform_name = "ncm-radio"
enable_tag = False enable_tag = False
@ -92,9 +83,7 @@ class NcmRadio(NewMessage):
parse_target_promot = "请输入主播电台主页包含数字ID的链接" parse_target_promot = "请输入主播电台主页包含数字ID的链接"
@classmethod @classmethod
async def get_target_name( async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
cls, client: AsyncClient, target: Target
) -> Optional[str]:
res = await client.post( res = await client.post(
"http://music.163.com/api/dj/program/byradio", "http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"}, headers={"Referer": "https://music.163.com/"},
@ -109,9 +98,7 @@ class NcmRadio(NewMessage):
async def parse_target(cls, target_text: str) -> Target: async def parse_target(cls, target_text: str) -> Target:
if re.match(r"^\d+$", target_text): if re.match(r"^\d+$", target_text):
return Target(target_text) return Target(target_text)
elif match := re.match( elif match := re.match(r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text):
r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text
):
return Target(match.group(1)) return Target(match.group(1))
else: else:
raise cls.ParseTargetException() raise cls.ParseTargetException()

View File

@ -1,29 +1,32 @@
import json
import ssl import ssl
import json
import time import time
import typing import typing
from typing import Any
from dataclasses import dataclass
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from collections.abc import Collection
from typing import Any, Collection, Optional, Type
import httpx import httpx
from httpx import AsyncClient from httpx import AsyncClient
from nonebot.log import logger from nonebot.log import logger
from nonebot_plugin_saa import PlatformTarget from nonebot_plugin_saa import PlatformTarget
from ..plugin_config import plugin_config
from ..post import Post from ..post import Post
from ..types import Category, RawPost, Tag, Target, UserSubInfo from ..plugin_config import plugin_config
from ..utils import ProcessContext, SchedulerConfig from ..utils import ProcessContext, SchedulerConfig
from ..types import Tag, Target, RawPost, Category, UserSubInfo
class CategoryNotSupport(Exception): class CategoryNotSupport(Exception):
"raise in get_category, when you know the category of the post but don't want to support it or don't support its parsing yet" """raise in get_category, when you know the category of the post
but don't want to support it or don't support its parsing yet
"""
class CategoryNotRecognize(Exception): class CategoryNotRecognize(Exception):
"raise in get_category, when you don't know the category of post" """raise in get_category, when you don't know the category of post"""
class RegistryMeta(type): class RegistryMeta(type):
@ -42,7 +45,6 @@ class RegistryMeta(type):
class PlatformMeta(RegistryMeta): class PlatformMeta(RegistryMeta):
categories: dict[Category, str] categories: dict[Category, str]
store: dict[Target, Any] store: dict[Target, Any]
@ -60,8 +62,7 @@ class PlatformABCMeta(PlatformMeta, ABC):
class Platform(metaclass=PlatformABCMeta, base=True): class Platform(metaclass=PlatformABCMeta, base=True):
scheduler: type[SchedulerConfig]
scheduler: Type[SchedulerConfig]
ctx: ProcessContext ctx: ProcessContext
is_common: bool is_common: bool
enabled: bool enabled: bool
@ -70,16 +71,14 @@ class Platform(metaclass=PlatformABCMeta, base=True):
categories: dict[Category, str] categories: dict[Category, str]
enable_tag: bool enable_tag: bool
platform_name: str platform_name: str
parse_target_promot: Optional[str] = None parse_target_promot: str | None = None
registry: list[Type["Platform"]] registry: list[type["Platform"]]
client: AsyncClient client: AsyncClient
reverse_category: dict[str, Category] reverse_category: dict[str, Category]
@classmethod @classmethod
@abstractmethod @abstractmethod
async def get_target_name( async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
cls, client: AsyncClient, target: Target
) -> Optional[str]:
... ...
@abstractmethod @abstractmethod
@ -95,11 +94,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
return await self.fetch_new_post(target, users) return await self.fetch_new_post(target, users)
except httpx.RequestError as err: except httpx.RequestError as err:
if plugin_config.bison_show_network_warning: if plugin_config.bison_show_network_warning:
logger.warning( logger.warning(f"network connection error: {type(err)}, url: {err.request.url}")
"network connection error: {}, url: {}".format(
type(err), err.request.url
)
)
return [] return []
except ssl.SSLError as err: except ssl.SSLError as err:
if plugin_config.bison_show_network_warning: if plugin_config.bison_show_network_warning:
@ -130,7 +125,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
return Target(target_string) return Target(target_string)
@abstractmethod @abstractmethod
def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]: def get_tags(self, raw_post: RawPost) -> Collection[Tag] | None:
"Return Tag list of given RawPost" "Return Tag list of given RawPost"
@classmethod @classmethod
@ -201,9 +196,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
) -> list[tuple[PlatformTarget, list[Post]]]: ) -> list[tuple[PlatformTarget, list[Post]]]:
res: list[tuple[PlatformTarget, list[Post]]] = [] res: list[tuple[PlatformTarget, list[Post]]] = []
for user, cats, required_tags in users: for user, cats, required_tags in users:
user_raw_post = await self.filter_user_custom( user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
new_posts, cats, required_tags
)
user_post: list[Post] = [] user_post: list[Post] = []
for raw_post in user_raw_post: for raw_post in user_raw_post:
user_post.append(await self.do_parse(raw_post)) user_post.append(await self.do_parse(raw_post))
@ -211,7 +204,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
return res return res
@abstractmethod @abstractmethod
def get_category(self, post: RawPost) -> Optional[Category]: def get_category(self, post: RawPost) -> Category | None:
"Return category of given Rawpost" "Return category of given Rawpost"
raise NotImplementedError() raise NotImplementedError()
@ -221,7 +214,7 @@ class MessageProcess(Platform, abstract=True):
def __init__(self, ctx: ProcessContext, client: AsyncClient): def __init__(self, ctx: ProcessContext, client: AsyncClient):
super().__init__(ctx, client) super().__init__(ctx, client)
self.parse_cache: dict[Any, Post] = dict() self.parse_cache: dict[Any, Post] = {}
@abstractmethod @abstractmethod
def get_id(self, post: RawPost) -> Any: def get_id(self, post: RawPost) -> Any:
@ -246,7 +239,7 @@ class MessageProcess(Platform, abstract=True):
"Get post list of the given target" "Get post list of the given target"
@abstractmethod @abstractmethod
def get_date(self, post: RawPost) -> Optional[int]: def get_date(self, post: RawPost) -> int | None:
"Get post timestamp and return, return None if can't get the time" "Get post timestamp and return, return None if can't get the time"
async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]: async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
@ -286,9 +279,7 @@ class NewMessage(MessageProcess, abstract=True):
inited: bool inited: bool
exists_posts: set[Any] exists_posts: set[Any]
async def filter_common_with_diff( async def filter_common_with_diff(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]:
self, target: Target, raw_post_list: list[RawPost]
) -> list[RawPost]:
filtered_post = await self.filter_common(raw_post_list) filtered_post = await self.filter_common(raw_post_list)
store = self.get_stored_data(target) or self.MessageStorage(False, set()) store = self.get_stored_data(target) or self.MessageStorage(False, set())
res = [] res = []
@ -297,11 +288,7 @@ class NewMessage(MessageProcess, abstract=True):
for raw_post in filtered_post: for raw_post in filtered_post:
post_id = self.get_id(raw_post) post_id = self.get_id(raw_post)
store.exists_posts.add(post_id) store.exists_posts.add(post_id)
logger.info( logger.info(f"init {self.platform_name}-{target} with {store.exists_posts}")
"init {}-{} with {}".format(
self.platform_name, target, store.exists_posts
)
)
store.inited = True store.inited = True
else: else:
for raw_post in filtered_post: for raw_post in filtered_post:
@ -400,12 +387,11 @@ class SimplePost(MessageProcess, abstract=True):
return res return res
def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]: def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
class NoTargetGroup(Platform, abstract=True): class NoTargetGroup(Platform, abstract=True):
platform_list: list[Type[Platform]] platform_list: list[type[Platform]]
platform_obj_list: list[Platform] platform_obj_list: list[Platform]
DUMMY_STR = "_DUMMY" DUMMY_STR = "_DUMMY"
@ -418,24 +404,18 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
for platform in platform_list: for platform in platform_list:
if platform.has_target: if platform.has_target:
raise RuntimeError( raise RuntimeError(f"Platform {platform.name} should have no target")
"Platform {} should have no target".format(platform.name)
)
if name == DUMMY_STR: if name == DUMMY_STR:
name = platform.name name = platform.name
elif name != platform.name: elif name != platform.name:
raise RuntimeError("Platform name for {} not fit".format(platform_name)) raise RuntimeError(f"Platform name for {platform_name} not fit")
platform_category_key_set = set(platform.categories.keys()) platform_category_key_set = set(platform.categories.keys())
if platform_category_key_set & categories_keys: if platform_category_key_set & categories_keys:
raise RuntimeError( raise RuntimeError(f"Platform categories for {platform_name} duplicate")
"Platform categories for {} duplicate".format(platform_name)
)
categories_keys |= platform_category_key_set categories_keys |= platform_category_key_set
categories.update(platform.categories) categories.update(platform.categories)
if platform.scheduler != scheduler: if platform.scheduler != scheduler:
raise RuntimeError( raise RuntimeError(f"Platform scheduler for {platform_name} not fit")
"Platform scheduler for {} not fit".format(platform_name)
)
def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient): def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient):
Platform.__init__(self, ctx, client) Platform.__init__(self, ctx, client)
@ -444,15 +424,13 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
self.platform_obj_list.append(platform_class(ctx, client)) self.platform_obj_list.append(platform_class(ctx, client))
def __str__(self: "NoTargetGroup") -> str: def __str__(self: "NoTargetGroup") -> str:
return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]" return "[" + " ".join(x.name for x in self.platform_list) + "]"
@classmethod @classmethod
async def get_target_name(cls, client: AsyncClient, target: Target): async def get_target_name(cls, client: AsyncClient, target: Target):
return await platform_list[0].get_target_name(client, target) return await platform_list[0].get_target_name(client, target)
async def fetch_new_post( async def fetch_new_post(self: "NoTargetGroup", target: Target, users: list[UserSubInfo]):
self: "NoTargetGroup", target: Target, users: list[UserSubInfo]
):
res = defaultdict(list) res = defaultdict(list)
for platform in self.platform_obj_list: for platform in self.platform_obj_list:
platform_res = await platform.fetch_new_post(target=target, users=users) platform_res = await platform.fetch_new_post(target=target, users=users)

View File

@ -1,10 +1,10 @@
import calendar import calendar
import time import time
from typing import Any, Optional from typing import Any
import feedparser import feedparser
from bs4 import BeautifulSoup as bs
from httpx import AsyncClient from httpx import AsyncClient
from bs4 import BeautifulSoup as bs
from ..post import Post from ..post import Post
from ..types import RawPost, Target from ..types import RawPost, Target
@ -20,7 +20,6 @@ class RssSchedConf(SchedulerConfig):
class Rss(NewMessage): class Rss(NewMessage):
categories = {} categories = {}
enable_tag = False enable_tag = False
platform_name = "rss" platform_name = "rss"
@ -31,9 +30,7 @@ class Rss(NewMessage):
has_target = True has_target = True
@classmethod @classmethod
async def get_target_name( async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
cls, client: AsyncClient, target: Target
) -> Optional[str]:
res = await client.get(target, timeout=10.0) res = await client.get(target, timeout=10.0)
feed = feedparser.parse(res.text) feed = feedparser.parse(res.text)
return feed["feed"]["title"] return feed["feed"]["title"]
@ -69,7 +66,7 @@ class Rss(NewMessage):
else: else:
text = f"{title}\n\n{desc}" text = f"{title}\n\n{desc}"
pics = list(map(lambda x: x.attrs["src"], soup("img"))) pics = [x.attrs["src"] for x in soup("img")]
if raw_post.get("media_content"): if raw_post.get("media_content"):
for media in raw_post["media_content"]: for media in raw_post["media_content"]:
if media.get("medium") == "image" and media.get("url"): if media.get("medium") == "image" and media.get("url"):

View File

@ -1,17 +1,16 @@
import json
import re import re
from collections.abc import Callable import json
from typing import Any
from datetime import datetime from datetime import datetime
from typing import Any, Optional
from bs4 import BeautifulSoup as bs
from httpx import AsyncClient from httpx import AsyncClient
from nonebot.log import logger from nonebot.log import logger
from bs4 import BeautifulSoup as bs
from ..post import Post from ..post import Post
from ..types import *
from ..utils import SchedulerConfig, http_client
from .platform import NewMessage from .platform import NewMessage
from ..utils import SchedulerConfig, http_client
from ..types import Tag, Target, RawPost, ApiError, Category
class WeiboSchedConf(SchedulerConfig): class WeiboSchedConf(SchedulerConfig):
@ -21,7 +20,6 @@ class WeiboSchedConf(SchedulerConfig):
class Weibo(NewMessage): class Weibo(NewMessage):
categories = { categories = {
1: "转发", 1: "转发",
2: "视频", 2: "视频",
@ -38,13 +36,9 @@ class Weibo(NewMessage):
parse_target_promot = "请输入用户主页包含数字UID的链接" parse_target_promot = "请输入用户主页包含数字UID的链接"
@classmethod @classmethod
async def get_target_name( async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
cls, client: AsyncClient, target: Target
) -> Optional[str]:
param = {"containerid": "100505" + target} param = {"containerid": "100505" + target}
res = await client.get( res = await client.get("https://m.weibo.cn/api/container/getIndex", params=param)
"https://m.weibo.cn/api/container/getIndex", params=param
)
res_dict = json.loads(res.text) res_dict = json.loads(res.text)
if res_dict.get("ok") == 1: if res_dict.get("ok") == 1:
return res_dict["data"]["userInfo"]["screen_name"] return res_dict["data"]["userInfo"]["screen_name"]
@ -63,13 +57,14 @@ class Weibo(NewMessage):
async def get_sub_list(self, target: Target) -> list[RawPost]: async def get_sub_list(self, target: Target) -> list[RawPost]:
params = {"containerid": "107603" + target} params = {"containerid": "107603" + target}
res = await self.client.get( res = await self.client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0)
"https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0
)
res_data = json.loads(res.text) res_data = json.loads(res.text)
if not res_data["ok"] and res_data["msg"] != "这里还没有内容": if not res_data["ok"] and res_data["msg"] != "这里还没有内容":
raise ApiError(res.request.url) raise ApiError(res.request.url)
custom_filter: Callable[[RawPost], bool] = lambda d: d["card_type"] == 9
def custom_filter(d: RawPost) -> bool:
return d["card_type"] == 9
return list(filter(custom_filter, res_data["data"]["cards"])) return list(filter(custom_filter, res_data["data"]["cards"]))
def get_id(self, post: RawPost) -> Any: def get_id(self, post: RawPost) -> Any:
@ -79,44 +74,32 @@ class Weibo(NewMessage):
return raw_post["card_type"] == 9 return raw_post["card_type"] == 9
def get_date(self, raw_post: RawPost) -> float: def get_date(self, raw_post: RawPost) -> float:
created_time = datetime.strptime( created_time = datetime.strptime(raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y")
raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y"
)
return created_time.timestamp() return created_time.timestamp()
def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]: def get_tags(self, raw_post: RawPost) -> list[Tag] | None:
"Return Tag list of given RawPost" "Return Tag list of given RawPost"
text = raw_post["mblog"]["text"] text = raw_post["mblog"]["text"]
soup = bs(text, "html.parser") soup = bs(text, "html.parser")
res = list( res = [
map( x[1:-1]
lambda x: x[1:-1], for x in filter(
filter(
lambda s: s[0] == "#" and s[-1] == "#", lambda s: s[0] == "#" and s[-1] == "#",
map(lambda x: x.text, soup.find_all("span", class_="surl-text")), (x.text for x in soup.find_all("span", class_="surl-text")),
),
)
)
super_topic_img = soup.find(
"img", src=re.compile(r"timeline_card_small_super_default")
) )
]
super_topic_img = soup.find("img", src=re.compile(r"timeline_card_small_super_default"))
if super_topic_img: if super_topic_img:
try: try:
res.append( res.append(super_topic_img.parent.parent.find("span", class_="surl-text").text + "超话") # type: ignore
super_topic_img.parent.parent.find("span", class_="surl-text").text # type: ignore except Exception:
+ "超话" logger.info(f"super_topic extract error: {text}")
)
except:
logger.info("super_topic extract error: {}".format(text))
return res return res
def get_category(self, raw_post: RawPost) -> Category: def get_category(self, raw_post: RawPost) -> Category:
if raw_post["mblog"].get("retweeted_status"): if raw_post["mblog"].get("retweeted_status"):
return Category(1) return Category(1)
elif ( elif raw_post["mblog"].get("page_info") and raw_post["mblog"]["page_info"].get("type") == "video":
raw_post["mblog"].get("page_info")
and raw_post["mblog"]["page_info"].get("type") == "video"
):
return Category(2) return Category(2)
elif raw_post["mblog"].get("pics"): elif raw_post["mblog"].get("pics"):
return Category(3) return Category(3)
@ -129,7 +112,8 @@ class Weibo(NewMessage):
async def parse(self, raw_post: RawPost) -> Post: async def parse(self, raw_post: RawPost) -> Post:
header = { header = {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,"
"*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
"accept-language": "zh-CN,zh;q=0.9", "accept-language": "zh-CN,zh;q=0.9",
"authority": "m.weibo.cn", "authority": "m.weibo.cn",
"cache-control": "max-age=0", "cache-control": "max-age=0",
@ -147,26 +131,16 @@ class Weibo(NewMessage):
retweeted = True retweeted = True
pic_num = info["retweeted_status"]["pic_num"] if retweeted else info["pic_num"] pic_num = info["retweeted_status"]["pic_num"] if retweeted else info["pic_num"]
if info["isLongText"] or pic_num > 9: if info["isLongText"] or pic_num > 9:
res = await self.client.get( res = await self.client.get(f"https://m.weibo.cn/detail/{info['mid']}", headers=header)
"https://m.weibo.cn/detail/{}".format(info["mid"]), headers=header
)
try: try:
match = re.search(r'"status": ([\s\S]+),\s+"call"', res.text) match = re.search(r'"status": ([\s\S]+),\s+"call"', res.text)
assert match assert match
full_json_text = match.group(1) full_json_text = match.group(1)
info = json.loads(full_json_text) info = json.loads(full_json_text)
except: except Exception:
logger.info( logger.info(f"detail message error: https://m.weibo.cn/detail/{info['mid']}")
"detail message error: https://m.weibo.cn/detail/{}".format(
info["mid"]
)
)
parsed_text = self._get_text(info["text"]) parsed_text = self._get_text(info["text"])
raw_pics_list = ( raw_pics_list = info["retweeted_status"].get("pics", []) if retweeted else info.get("pics", [])
info["retweeted_status"].get("pics", [])
if retweeted
else info.get("pics", [])
)
pic_urls = [img["large"]["url"] for img in raw_pics_list] pic_urls = [img["large"]["url"] for img in raw_pics_list]
pics = [] pics = []
for pic_url in pic_urls: for pic_url in pic_urls:
@ -174,7 +148,7 @@ class Weibo(NewMessage):
res = await client.get(pic_url) res = await client.get(pic_url)
res.raise_for_status() res.raise_for_status()
pics.append(res.content) pics.append(res.content)
detail_url = "https://weibo.com/{}/{}".format(info["user"]["id"], info["bid"]) detail_url = f"https://weibo.com/{info['user']['id']}/{info['bid']}"
# return parsed_text, detail_url, pic_urls # return parsed_text, detail_url, pic_urls
return Post( return Post(
"weibo", "weibo",

View File

@ -1,11 +1,8 @@
from typing import Optional
import nonebot import nonebot
from pydantic import BaseSettings from pydantic import BaseSettings
class PlugConfig(BaseSettings): class PlugConfig(BaseSettings):
bison_config_path: str = "" bison_config_path: str = ""
bison_use_pic: bool = False bison_use_pic: bool = False
bison_init_filter: bool = True bison_init_filter: bool = True
@ -17,8 +14,12 @@ class PlugConfig(BaseSettings):
bison_use_pic_merge: int = 0 # 多图片时启用图片合并转发(仅限群) bison_use_pic_merge: int = 0 # 多图片时启用图片合并转发(仅限群)
# 0不启用1首条消息单独发送剩余照片合并转发2以及以上所有消息全部合并转发 # 0不启用1首条消息单独发送剩余照片合并转发2以及以上所有消息全部合并转发
bison_resend_times: int = 0 bison_resend_times: int = 0
bison_proxy: Optional[str] bison_proxy: str | None
bison_ua: str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36" bison_ua: str = (
"Mozilla/5.0 (X11; Linux x86_64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/51.0.2704.103 Safari/537.36"
)
bison_show_network_warning: bool = True bison_show_network_warning: bool = True
class Config: class Config:

View File

@ -1,3 +1,3 @@
from .post import Post from .post import Post
__all__ = ["Post", "CustomPost"] __all__ = ["Post"]

View File

@ -1,7 +1,6 @@
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import reduce from functools import reduce
from typing import Optional from abc import abstractmethod
from dataclasses import field, dataclass
from nonebot_plugin_saa import MessageFactory, MessageSegmentFactory from nonebot_plugin_saa import MessageFactory, MessageSegmentFactory
@ -25,12 +24,12 @@ class BasePost:
class OptionalMixin: class OptionalMixin:
# Because of https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses # Because of https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
override_use_pic: Optional[bool] = None override_use_pic: bool | None = None
compress: bool = False compress: bool = False
extra_msg: list[MessageFactory] = field(default_factory=list) extra_msg: list[MessageFactory] = field(default_factory=list)
def _use_pic(self): def _use_pic(self):
if not self.override_use_pic is None: if self.override_use_pic is not None:
return self.override_use_pic return self.override_use_pic
return plugin_config.bison_use_pic return plugin_config.bison_use_pic
@ -44,13 +43,9 @@ class AbstractPost(OptionalMixin, BasePost):
msg_segments = await self.generate_text_messages() msg_segments = await self.generate_text_messages()
if msg_segments: if msg_segments:
if self.compress: if self.compress:
msgs = [ msgs = [reduce(lambda x, y: x.append(y), msg_segments, MessageFactory([]))]
reduce(lambda x, y: x.append(y), msg_segments, MessageFactory([]))
]
else: else:
msgs = list( msgs = [MessageFactory([msg_segment]) for msg_segment in msg_segments]
map(lambda msg_segment: MessageFactory([msg_segment]), msg_segments)
)
else: else:
msgs = [] msgs = []
msgs.extend(self.extra_msg) msgs.extend(self.extra_msg)

View File

@ -1,19 +1,17 @@
from dataclasses import dataclass, field from dataclasses import field, dataclass
from typing import Optional
from nonebot.adapters.onebot.v11 import MessageSegment
from nonebot.log import logger from nonebot.log import logger
from nonebot.plugin import require from nonebot.plugin import require
from nonebot_plugin_saa import Image, MessageFactory, MessageSegmentFactory, Text from nonebot.adapters.onebot.v11 import MessageSegment
from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
from .abstract_post import AbstractPost, BasePost from .abstract_post import BasePost, AbstractPost
@dataclass @dataclass
class _CustomPost(BasePost): class _CustomPost(BasePost):
ms_factories: list[MessageSegmentFactory] = field(default_factory=list) ms_factories: list[MessageSegmentFactory] = field(default_factory=list)
css_path: Optional[str] = None # 模板文件所用css路径 css_path: str | None = None # 模板文件所用css路径
async def generate_text_messages(self) -> list[MessageSegmentFactory]: async def generate_text_messages(self) -> list[MessageSegmentFactory]:
return self.ms_factories return self.ms_factories
@ -31,15 +29,13 @@ class _CustomPost(BasePost):
for message_segment in self.ms_factories: for message_segment in self.ms_factories:
match message_segment: match message_segment:
case Text(data={"text": text}): case Text(data={"text": text}):
md += "{}<br>".format(text) md += f"{text}<br>"
case Image(data={"image": image}): case Image(data={"image": image}):
# use onebot v11 to convert image into url # use onebot v11 to convert image into url
ob11_image = MessageSegment.image(image) ob11_image = MessageSegment.image(image)
md += "![Image]({})\n".format(ob11_image.data["file"]) md += "![Image]({})\n".format(ob11_image.data["file"])
case _: case _:
logger.warning( logger.warning(f"custom_post不支持处理类型:{type(message_segment)}")
"custom_post不支持处理类型:{}".format(type(message_segment))
)
continue continue
return md return md

View File

@ -1,30 +1,27 @@
from dataclasses import dataclass, field
from functools import reduce
from io import BytesIO from io import BytesIO
from typing import Optional, Union from dataclasses import field, dataclass
import nonebot_plugin_saa as saa
from nonebot.log import logger
from nonebot_plugin_saa.utils import MessageFactory, MessageSegmentFactory
from PIL import Image from PIL import Image
from nonebot.log import logger
import nonebot_plugin_saa as saa
from nonebot_plugin_saa.utils import MessageSegmentFactory
from ..utils import http_client, parse_text from ..utils import parse_text, http_client
from .abstract_post import AbstractPost, BasePost, OptionalMixin from .abstract_post import BasePost, AbstractPost
@dataclass @dataclass
class _Post(BasePost): class _Post(BasePost):
target_type: str target_type: str
text: str text: str
url: Optional[str] = None url: str | None = None
target_name: Optional[str] = None target_name: str | None = None
pics: list[Union[str, bytes]] = field(default_factory=list) pics: list[str | bytes] = field(default_factory=list)
_message: Optional[list[MessageSegmentFactory]] = None _message: list[MessageSegmentFactory] | None = None
_pic_message: Optional[list[MessageSegmentFactory]] = None _pic_message: list[MessageSegmentFactory] | None = None
async def _pic_url_to_image(self, data: Union[str, bytes]) -> Image.Image: async def _pic_url_to_image(self, data: str | bytes) -> Image.Image:
pic_buffer = BytesIO() pic_buffer = BytesIO()
if isinstance(data, str): if isinstance(data, str):
async with http_client() as client: async with http_client() as client:
@ -101,22 +98,19 @@ class _Post(BasePost):
self.pics.insert(0, target_io.getvalue()) self.pics.insert(0, target_io.getvalue())
async def generate_text_messages(self) -> list[MessageSegmentFactory]: async def generate_text_messages(self) -> list[MessageSegmentFactory]:
if self._message is None: if self._message is None:
await self._pic_merge() await self._pic_merge()
msg_segments: list[MessageSegmentFactory] = [] msg_segments: list[MessageSegmentFactory] = []
text = "" text = ""
if self.text: if self.text:
text += "{}".format( text += "{}".format(self.text if len(self.text) < 500 else self.text[:500] + "...")
self.text if len(self.text) < 500 else self.text[:500] + "..."
)
if text: if text:
text += "\n" text += "\n"
text += "来源: {}".format(self.target_type) text += f"来源: {self.target_type}"
if self.target_name: if self.target_name:
text += " {}".format(self.target_name) text += f" {self.target_name}"
if self.url: if self.url:
text += " \n详情: {}".format(self.url) text += f" \n详情: {self.url}"
msg_segments.append(saa.Text(text)) msg_segments.append(saa.Text(text))
for pic in self.pics: for pic in self.pics:
msg_segments.append(saa.Image(pic)) msg_segments.append(saa.Image(pic))
@ -124,17 +118,16 @@ class _Post(BasePost):
return self._message return self._message
async def generate_pic_messages(self) -> list[MessageSegmentFactory]: async def generate_pic_messages(self) -> list[MessageSegmentFactory]:
if self._pic_message is None: if self._pic_message is None:
await self._pic_merge() await self._pic_merge()
msg_segments: list[MessageSegmentFactory] = [] msg_segments: list[MessageSegmentFactory] = []
text = "" text = ""
if self.text: if self.text:
text += "{}".format(self.text) text += f"{self.text}"
text += "\n" text += "\n"
text += "来源: {}".format(self.target_type) text += f"来源: {self.target_type}"
if self.target_name: if self.target_name:
text += " {}".format(self.target_name) text += f" {self.target_name}"
msg_segments.append(await parse_text(text)) msg_segments.append(await parse_text(text))
if not self.target_type == "rss" and self.url: if not self.target_type == "rss" and self.url:
msg_segments.append(saa.Text(self.url)) msg_segments.append(saa.Text(self.url))
@ -149,14 +142,7 @@ class _Post(BasePost):
self.target_name, self.target_name,
self.text if len(self.text) < 500 else self.text[:500] + "...", self.text if len(self.text) < 500 else self.text[:500] + "...",
self.url, self.url,
", ".join( ", ".join("b64img" if isinstance(x, bytes) or x.startswith("base64") else x for x in self.pics),
map(
lambda x: "b64img"
if isinstance(x, bytes) or x.startswith("base64")
else x,
self.pics,
)
),
) )

View File

@ -1 +1,3 @@
from .manager import * from .manager import init_scheduler, scheduler_dict, handle_delete_target, handle_insert_new_target
__all__ = ["init_scheduler", "handle_delete_target", "handle_insert_new_target", "scheduler_dict"]

View File

@ -1,18 +1,16 @@
from typing import Type
from ..config import config from ..config import config
from ..config.db_model import Target
from ..platform import platform_manager
from ..types import Target as T_Target
from ..utils import SchedulerConfig
from .scheduler import Scheduler from .scheduler import Scheduler
from ..utils import SchedulerConfig
from ..config.db_model import Target
from ..types import Target as T_Target
from ..platform import platform_manager
scheduler_dict: dict[Type[SchedulerConfig], Scheduler] = {} scheduler_dict: dict[type[SchedulerConfig], Scheduler] = {}
async def init_scheduler(): async def init_scheduler():
_schedule_class_dict: dict[Type[SchedulerConfig], list[Target]] = {} _schedule_class_dict: dict[type[SchedulerConfig], list[Target]] = {}
_schedule_class_platform_dict: dict[Type[SchedulerConfig], list[str]] = {} _schedule_class_platform_dict: dict[type[SchedulerConfig], list[str]] = {}
for platform in platform_manager.values(): for platform in platform_manager.values():
scheduler_config = platform.scheduler scheduler_config = platform.scheduler
if not hasattr(scheduler_config, "name") or not scheduler_config.name: if not hasattr(scheduler_config, "name") or not scheduler_config.name:
@ -33,9 +31,7 @@ async def init_scheduler():
for target in target_list: for target in target_list:
schedulable_args.append((target.platform_name, T_Target(target.target))) schedulable_args.append((target.platform_name, T_Target(target.target)))
platform_name_list = _schedule_class_platform_dict[scheduler_config] platform_name_list = _schedule_class_platform_dict[scheduler_config]
scheduler_dict[scheduler_config] = Scheduler( scheduler_dict[scheduler_config] = Scheduler(scheduler_config, schedulable_args, platform_name_list)
scheduler_config, schedulable_args, platform_name_list
)
config.register_add_target_hook(handle_insert_new_target) config.register_add_target_hook(handle_insert_new_target)
config.register_delete_target_hook(handle_delete_target) config.register_delete_target_hook(handle_delete_target)

View File

@ -1,14 +1,13 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Type
from nonebot.log import logger from nonebot.log import logger
from nonebot_plugin_apscheduler import scheduler from nonebot_plugin_apscheduler import scheduler
from nonebot_plugin_saa.utils.exceptions import NoBotFound from nonebot_plugin_saa.utils.exceptions import NoBotFound
from ..config import config
from ..platform import platform_manager
from ..send import send_msgs
from ..types import Target from ..types import Target
from ..config import config
from ..send import send_msgs
from ..platform import platform_manager
from ..utils import ProcessContext, SchedulerConfig from ..utils import ProcessContext, SchedulerConfig
@ -20,12 +19,11 @@ class Schedulable:
class Scheduler: class Scheduler:
schedulable_list: list[Schedulable] schedulable_list: list[Schedulable]
def __init__( def __init__(
self, self,
scheduler_config: Type[SchedulerConfig], scheduler_config: type[SchedulerConfig],
schedulables: list[tuple[str, Target]], schedulables: list[tuple[str, Target]],
platform_name_list: list[str], platform_name_list: list[str],
): ):
@ -37,15 +35,12 @@ class Scheduler:
self.scheduler_config_obj = self.scheduler_config() self.scheduler_config_obj = self.scheduler_config()
self.schedulable_list = [] self.schedulable_list = []
for platform_name, target in schedulables: for platform_name, target in schedulables:
self.schedulable_list.append( self.schedulable_list.append(Schedulable(platform_name=platform_name, target=target, current_weight=0))
Schedulable(
platform_name=platform_name, target=target, current_weight=0
)
)
self.platform_name_list = platform_name_list self.platform_name_list = platform_name_list
self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值 self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值
logger.info( logger.info(
f"register scheduler for {self.name} with {self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}" f"register scheduler for {self.name} with "
f"{self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}"
) )
scheduler.add_job( scheduler.add_job(
self.exec_fetch, self.exec_fetch,
@ -53,7 +48,7 @@ class Scheduler:
**self.scheduler_config.schedule_setting, **self.scheduler_config.schedule_setting,
) )
async def get_next_schedulable(self) -> Optional[Schedulable]: async def get_next_schedulable(self) -> Schedulable | None:
if not self.schedulable_list: if not self.schedulable_list:
return None return None
cur_weight = await config.get_current_weight_val(self.platform_name_list) cur_weight = await config.get_current_weight_val(self.platform_name_list)
@ -61,16 +56,9 @@ class Scheduler:
self.pre_weight_val = 0 self.pre_weight_val = 0
cur_max_schedulable = None cur_max_schedulable = None
for schedulable in self.schedulable_list: for schedulable in self.schedulable_list:
schedulable.current_weight += cur_weight[ schedulable.current_weight += cur_weight[f"{schedulable.platform_name}-{schedulable.target}"]
f"{schedulable.platform_name}-{schedulable.target}" weight_sum += cur_weight[f"{schedulable.platform_name}-{schedulable.target}"]
] if not cur_max_schedulable or cur_max_schedulable.current_weight < schedulable.current_weight:
weight_sum += cur_weight[
f"{schedulable.platform_name}-{schedulable.target}"
]
if (
not cur_max_schedulable
or cur_max_schedulable.current_weight < schedulable.current_weight
):
cur_max_schedulable = schedulable cur_max_schedulable = schedulable
assert cur_max_schedulable assert cur_max_schedulable
cur_max_schedulable.current_weight -= weight_sum cur_max_schedulable.current_weight -= weight_sum
@ -80,9 +68,7 @@ class Scheduler:
context = ProcessContext() context = ProcessContext()
if not (schedulable := await self.get_next_schedulable()): if not (schedulable := await self.get_next_schedulable()):
return return
logger.trace( logger.trace(f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}")
f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}"
)
send_userinfo_list = await config.get_platform_target_subscribers( send_userinfo_list = await config.get_platform_target_subscribers(
schedulable.platform_name, schedulable.target schedulable.platform_name, schedulable.target
) )
@ -92,9 +78,7 @@ class Scheduler:
try: try:
platform_obj = platform_manager[schedulable.platform_name](context, client) platform_obj = platform_manager[schedulable.platform_name](context, client)
to_send = await platform_obj.do_fetch_new_post( to_send = await platform_obj.do_fetch_new_post(schedulable.target, send_userinfo_list)
schedulable.target, send_userinfo_list
)
except Exception as err: except Exception as err:
records = context.gen_req_records() records = context.gen_req_records()
for record in records: for record in records:
@ -107,7 +91,7 @@ class Scheduler:
for user, send_list in to_send: for user, send_list in to_send:
for send_post in send_list: for send_post in send_list:
logger.info("send to {}: {}".format(user, send_post)) logger.info(f"send to {user}: {send_post}")
try: try:
await send_msgs( await send_msgs(
user, user,
@ -119,19 +103,14 @@ class Scheduler:
def insert_new_schedulable(self, platform_name: str, target: Target): def insert_new_schedulable(self, platform_name: str, target: Target):
self.pre_weight_val += 1000 self.pre_weight_val += 1000
self.schedulable_list.append(Schedulable(platform_name, target, 1000)) self.schedulable_list.append(Schedulable(platform_name, target, 1000))
logger.info( logger.info(f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})")
f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})"
)
def delete_schedulable(self, platform_name, target: Target): def delete_schedulable(self, platform_name, target: Target):
if not self.schedulable_list: if not self.schedulable_list:
return return
to_find_idx = None to_find_idx = None
for idx, schedulable in enumerate(self.schedulable_list): for idx, schedulable in enumerate(self.schedulable_list):
if ( if schedulable.platform_name == platform_name and schedulable.target == target:
schedulable.platform_name == platform_name
and schedulable.target == target
):
to_find_idx = idx to_find_idx = idx
break break
if to_find_idx is not None: if to_find_idx is not None:

View File

@ -1,21 +1,23 @@
import importlib
import json import json
import time import time
from functools import partial, wraps import importlib
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import Any, Callable, Coroutine, TypeVar from typing import Any, TypeVar
from functools import wraps, partial
from collections.abc import Callable, Coroutine
from nonebot.log import logger from nonebot.log import logger
from ..config.subs_io import subscribes_export, subscribes_import
from ..config.subs_io.nbesf_model import v1, v2
from ..scheduler.manager import init_scheduler from ..scheduler.manager import init_scheduler
from ..config.subs_io.nbesf_model import v1, v2
from ..config.subs_io import subscribes_export, subscribes_import
try: try:
from typing_extensions import ParamSpec
import anyio import anyio
import click import click
from typing_extensions import ParamSpec
except ImportError as e: # pragma: no cover except ImportError as e: # pragma: no cover
raise ImportError("请使用 `pip install nonebot-bison[cli]` 安装所需依赖") from e raise ImportError("请使用 `pip install nonebot-bison[cli]` 安装所需依赖") from e
@ -65,9 +67,7 @@ def path_init(ctx, param, value):
@cli.command(help="导出Nonebot Bison Exchangable Subcribes File", name="export") @cli.command(help="导出Nonebot Bison Exchangable Subcribes File", name="export")
@click.option( @click.option("--path", "-p", default=None, callback=path_init, help="导出路径, 如果不指定,则默认为工作目录")
"--path", "-p", default=None, callback=path_init, help="导出路径, 如果不指定,则默认为工作目录"
)
@click.option( @click.option(
"--format", "--format",
default="json", default="json",
@ -76,7 +76,6 @@ def path_init(ctx, param, value):
) )
@run_async @run_async
async def subs_export(path: Path, format: str): async def subs_export(path: Path, format: str):
await init_scheduler() await init_scheduler()
export_file = path / f"bison_subscribes_export_{int(time.time())}.{format}" export_file = path / f"bison_subscribes_export_{int(time.time())}.{format}"
@ -121,7 +120,6 @@ async def subs_export(path: Path, format: str):
) )
@run_async @run_async
async def subs_import(path: str, format: str): async def subs_import(path: str, format: str):
await init_scheduler() await init_scheduler()
import_file_path = Path(path) import_file_path = Path(path)

View File

@ -1,17 +1,16 @@
import asyncio import asyncio
from collections import deque from collections import deque
from typing import Deque
from nonebot.adapters.onebot.v11.exception import ActionFailed
from nonebot.log import logger from nonebot.log import logger
from nonebot_plugin_saa import AggregatedMessageFactory, MessageFactory, PlatformTarget from nonebot.adapters.onebot.v11.exception import ActionFailed
from nonebot_plugin_saa.utils.auto_select_bot import refresh_bots from nonebot_plugin_saa.utils.auto_select_bot import refresh_bots
from nonebot_plugin_saa import MessageFactory, PlatformTarget, AggregatedMessageFactory
from .plugin_config import plugin_config from .plugin_config import plugin_config
Sendable = MessageFactory | AggregatedMessageFactory Sendable = MessageFactory | AggregatedMessageFactory
QUEUE: Deque[tuple[PlatformTarget, Sendable, int]] = deque() QUEUE: deque[tuple[PlatformTarget, Sendable, int]] = deque()
MESSGE_SEND_INTERVAL = 1.5 MESSGE_SEND_INTERVAL = 1.5

View File

@ -1,21 +1,20 @@
import contextlib import contextlib
from typing import Type, cast
from nonebot.adapters import Message, MessageTemplate from nonebot.typing import T_State
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.params import Arg, ArgPlainText from nonebot.params import Arg, ArgPlainText
from nonebot.typing import T_State from nonebot.adapters import Message, MessageTemplate
from nonebot_plugin_saa import PlatformTarget, SupportedAdapters, Text from nonebot_plugin_saa import Text, PlatformTarget, SupportedAdapters
from ..apis import check_sub_target
from ..config import config
from ..config.db_config import SubscribeDupException
from ..platform import Platform, platform_manager
from ..types import Target from ..types import Target
from ..config import config
from ..apis import check_sub_target
from ..platform import Platform, platform_manager
from ..config.db_config import SubscribeDupException
from .utils import common_platform, ensure_user_info, gen_handle_cancel from .utils import common_platform, ensure_user_info, gen_handle_cancel
def do_add_sub(add_sub: Type[Matcher]): def do_add_sub(add_sub: type[Matcher]):
handle_cancel = gen_handle_cancel(add_sub, "已中止订阅") handle_cancel = gen_handle_cancel(add_sub, "已中止订阅")
add_sub.handle()(ensure_user_info(add_sub)) add_sub.handle()(ensure_user_info(add_sub))
@ -25,12 +24,7 @@ def do_add_sub(add_sub: Type[Matcher]):
state["_prompt"] = ( state["_prompt"] = (
"请输入想要订阅的平台,目前支持,请输入冒号左边的名称:\n" "请输入想要订阅的平台,目前支持,请输入冒号左边的名称:\n"
+ "".join( + "".join(
[ [f"{platform_name}: {platform_manager[platform_name].name}\n" for platform_name in common_platform]
"{}{}\n".format(
platform_name, platform_manager[platform_name].name
)
for platform_name in common_platform
]
) )
+ "要查看全部平台请输入:“全部”\n中止订阅过程请输入:“取消”" + "要查看全部平台请输入:“全部”\n中止订阅过程请输入:“取消”"
) )
@ -39,10 +33,7 @@ def do_add_sub(add_sub: Type[Matcher]):
async def parse_platform(state: T_State, platform: str = ArgPlainText()) -> None: async def parse_platform(state: T_State, platform: str = ArgPlainText()) -> None:
if platform == "全部": if platform == "全部":
message = "全部平台\n" + "\n".join( message = "全部平台\n" + "\n".join(
[ [f"{platform_name}: {platform.name}" for platform_name, platform in platform_manager.items()]
"{}{}".format(platform_name, platform.name)
for platform_name, platform in platform_manager.items()
]
) )
await add_sub.reject(message) await add_sub.reject(message)
elif platform == "取消": elif platform == "取消":
@ -57,9 +48,7 @@ def do_add_sub(add_sub: Type[Matcher]):
cur_platform = platform_manager[state["platform"]] cur_platform = platform_manager[state["platform"]]
if cur_platform.has_target: if cur_platform.has_target:
state["_prompt"] = ( state["_prompt"] = (
("1." + cur_platform.parse_target_promot + "\n2.") ("1." + cur_platform.parse_target_promot + "\n2.") if cur_platform.parse_target_promot else ""
if cur_platform.parse_target_promot
else ""
) + "请输入订阅用户的id\n查询id获取方法请回复:“查询”" ) + "请输入订阅用户的id\n查询id获取方法请回复:“查询”"
else: else:
matcher.set_arg("raw_id", None) # type: ignore matcher.set_arg("raw_id", None) # type: ignore
@ -81,9 +70,7 @@ def do_add_sub(add_sub: Type[Matcher]):
image = "https://s3.bmp.ovh/imgs/2022/03/ab3cc45d83bd3dd3.jpg" image = "https://s3.bmp.ovh/imgs/2022/03/ab3cc45d83bd3dd3.jpg"
msg.overwrite( msg.overwrite(
SupportedAdapters.onebot_v11, SupportedAdapters.onebot_v11,
MessageSegment.share( MessageSegment.share(url=url, title=title, content=content, image=image),
url=url, title=title, content=content, image=image
),
) )
await msg.reject() await msg.reject()
platform = platform_manager[state["platform"]] platform = platform_manager[state["platform"]]
@ -99,14 +86,12 @@ def do_add_sub(add_sub: Type[Matcher]):
await add_sub.reject("id输入错误") await add_sub.reject("id输入错误")
state["id"] = raw_id_text state["id"] = raw_id_text
state["name"] = name state["name"] = name
except (Platform.ParseTargetException): except Platform.ParseTargetException:
await add_sub.reject("不能从你的输入中提取出id请检查你输入的内容是否符合预期") await add_sub.reject("不能从你的输入中提取出id请检查你输入的内容是否符合预期")
else: else:
await add_sub.send( await add_sub.send(
"即将订阅的用户为:{} {} {}\n如有错误请输入“取消”重新订阅".format( f"即将订阅的用户为:{state['platform']} {state['name']} {state['id']}\n如有错误请输入“取消”重新订阅"
state["platform"], state["name"], state["id"] ) # noqa: E501
)
)
@add_sub.handle() @add_sub.handle()
async def prepare_get_categories(matcher: Matcher, state: T_State): async def prepare_get_categories(matcher: Matcher, state: T_State):
@ -125,7 +110,7 @@ def do_add_sub(add_sub: Type[Matcher]):
if platform_manager[state["platform"]].categories: if platform_manager[state["platform"]].categories:
for cat in raw_cats_text.split(): for cat in raw_cats_text.split():
if cat not in platform_manager[state["platform"]].reverse_category: if cat not in platform_manager[state["platform"]].reverse_category:
await add_sub.reject("不支持 {}".format(cat)) await add_sub.reject(f"不支持 {cat}")
res.append(platform_manager[state["platform"]].reverse_category[cat]) res.append(platform_manager[state["platform"]].reverse_category[cat])
state["cats"] = res state["cats"] = res
@ -135,14 +120,16 @@ def do_add_sub(add_sub: Type[Matcher]):
matcher.set_arg("raw_tags", None) # type: ignore matcher.set_arg("raw_tags", None) # type: ignore
state["tags"] = [] state["tags"] = []
return return
state["_prompt"] = '请输入要订阅/屏蔽的标签(不含#号)\n多个标签请使用空格隔开\n订阅所有标签输入"全部标签"\n具体规则回复"详情"' state["_prompt"] = "请输入要订阅/屏蔽的标签(不含#号)\n" "多个标签请使用空格隔开\n" '订阅所有标签输入"全部标签"\n' '具体规则回复"详情"' # noqa: E501
@add_sub.got("raw_tags", MessageTemplate("{_prompt}"), [handle_cancel]) @add_sub.got("raw_tags", MessageTemplate("{_prompt}"), [handle_cancel])
async def parser_tags(state: T_State, raw_tags: Message = Arg()): async def parser_tags(state: T_State, raw_tags: Message = Arg()):
raw_tags_text = raw_tags.extract_plain_text() raw_tags_text = raw_tags.extract_plain_text()
if raw_tags_text == "详情": if raw_tags_text == "详情":
await add_sub.reject( await add_sub.reject(
"订阅标签直接输入标签内容\n屏蔽标签请在标签名称前添加~号\n详见https://nonebot-bison.netlify.app/usage/#%E5%B9%B3%E5%8F%B0%E8%AE%A2%E9%98%85%E6%A0%87%E7%AD%BE-tag" "订阅标签直接输入标签内容\n"
"屏蔽标签请在标签名称前添加~号\n"
"详见https://nonebot-bison.netlify.app/usage/#%E5%B9%B3%E5%8F%B0%E8%AE%A2%E9%98%85%E6%A0%87%E7%AD%BE-tag"
) )
if raw_tags_text in ["全部标签", "全部", "全标签"]: if raw_tags_text in ["全部标签", "全部", "全标签"]:
state["tags"] = [] state["tags"] = []
@ -150,9 +137,7 @@ def do_add_sub(add_sub: Type[Matcher]):
state["tags"] = raw_tags_text.split() state["tags"] = raw_tags_text.split()
@add_sub.handle() @add_sub.handle()
async def add_sub_process( async def add_sub_process(state: T_State, user: PlatformTarget = Arg("target_user_info")):
state: T_State, user: PlatformTarget = Arg("target_user_info")
):
try: try:
await config.add_subscribe( await config.add_subscribe(
user=user, user=user,

View File

@ -1,26 +1,22 @@
from typing import Type from nonebot.typing import T_State
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.params import Arg, EventPlainText from nonebot.params import Arg, EventPlainText
from nonebot.typing import T_State
from nonebot_plugin_saa import MessageFactory, PlatformTarget from nonebot_plugin_saa import MessageFactory, PlatformTarget
from ..config import config from ..config import config
from ..platform import platform_manager
from ..types import Category from ..types import Category
from ..utils import parse_text from ..utils import parse_text
from ..platform import platform_manager
from .utils import ensure_user_info, gen_handle_cancel from .utils import ensure_user_info, gen_handle_cancel
def do_del_sub(del_sub: Type[Matcher]): def do_del_sub(del_sub: type[Matcher]):
handle_cancel = gen_handle_cancel(del_sub, "删除中止") handle_cancel = gen_handle_cancel(del_sub, "删除中止")
del_sub.handle()(ensure_user_info(del_sub)) del_sub.handle()(ensure_user_info(del_sub))
@del_sub.handle() @del_sub.handle()
async def send_list( async def send_list(state: T_State, user_info: PlatformTarget = Arg("target_user_info")):
state: T_State, user_info: PlatformTarget = Arg("target_user_info")
):
sub_list = await config.list_subscribe(user_info) sub_list = await config.list_subscribe(user_info)
if not sub_list: if not sub_list:
await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅") await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅")
@ -31,22 +27,10 @@ def do_del_sub(del_sub: Type[Matcher]):
"platform_name": sub.target.platform_name, "platform_name": sub.target.platform_name,
"target": sub.target.target, "target": sub.target.target,
} }
res += "{} {} {} {}\n".format( res += f"{index} {sub.target.platform_name} {sub.target.target_name} {sub.target.target}\n"
index,
sub.target.platform_name,
sub.target.target_name,
sub.target.target,
)
platform = platform_manager[sub.target.platform_name] platform = platform_manager[sub.target.platform_name]
if platform.categories: if platform.categories:
res += " [{}]".format( res += " [{}]".format(", ".join(platform.categories[Category(x)] for x in sub.categories))
", ".join(
map(
lambda x: platform.categories[Category(x)],
sub.categories,
)
)
)
if platform.enable_tag: if platform.enable_tag:
res += " {}".format(", ".join(sub.tags)) res += " {}".format(", ".join(sub.tags))
res += "\n" res += "\n"
@ -62,7 +46,7 @@ def do_del_sub(del_sub: Type[Matcher]):
try: try:
index = int(index_str) index = int(index_str)
await config.del_subscribe(user_info, **state["sub_table"][index]) await config.del_subscribe(user_info, **state["sub_table"][index])
except Exception as e: except Exception:
await del_sub.reject("删除错误") await del_sub.reject("删除错误")
else: else:
await del_sub.finish("删除成功") await del_sub.finish("删除成功")

View File

@ -1,17 +1,15 @@
from typing import Type
from nonebot.matcher import Matcher
from nonebot.params import Arg from nonebot.params import Arg
from nonebot.matcher import Matcher
from nonebot_plugin_saa import MessageFactory, PlatformTarget from nonebot_plugin_saa import MessageFactory, PlatformTarget
from ..config import config from ..config import config
from ..platform import platform_manager
from ..types import Category from ..types import Category
from ..utils import parse_text from ..utils import parse_text
from .utils import ensure_user_info from .utils import ensure_user_info
from ..platform import platform_manager
def do_query_sub(query_sub: Type[Matcher]): def do_query_sub(query_sub: type[Matcher]):
query_sub.handle()(ensure_user_info(query_sub)) query_sub.handle()(ensure_user_info(query_sub))
@query_sub.handle() @query_sub.handle()
@ -19,19 +17,10 @@ def do_query_sub(query_sub: Type[Matcher]):
sub_list = await config.list_subscribe(user_info) sub_list = await config.list_subscribe(user_info)
res = "订阅的帐号为:\n" res = "订阅的帐号为:\n"
for sub in sub_list: for sub in sub_list:
res += "{} {} {}".format( res += f"{sub.target.platform_name} {sub.target.target_name} {sub.target.target}"
# sub["target_type"], sub["target_name"], sub["target"]
sub.target.platform_name,
sub.target.target_name,
sub.target.target,
)
platform = platform_manager[sub.target.platform_name] platform = platform_manager[sub.target.platform_name]
if platform.categories: if platform.categories:
res += " [{}]".format( res += " [{}]".format(", ".join(platform.categories[Category(x)] for x in sub.categories))
", ".join(
map(lambda x: platform.categories[Category(x)], sub.categories)
)
)
if platform.enable_tag: if platform.enable_tag:
res += " {}".format(", ".join(sub.tags)) res += " {}".format(", ".join(sub.tags))
res += "\n" res += "\n"

View File

@ -1,13 +1,13 @@
import contextlib import contextlib
from typing import Annotated, Type from typing import Annotated
from nonebot.adapters import Event
from nonebot.matcher import Matcher
from nonebot.params import Depends, EventPlainText, EventToMe
from nonebot.permission import SUPERUSER
from nonebot.rule import Rule from nonebot.rule import Rule
from nonebot.adapters import Event
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.matcher import Matcher
from nonebot.permission import SUPERUSER
from nonebot_plugin_saa import extract_target from nonebot_plugin_saa import extract_target
from nonebot.params import Depends, EventToMe, EventPlainText
from ..platform import platform_manager from ..platform import platform_manager
from ..plugin_config import plugin_config from ..plugin_config import plugin_config
@ -31,7 +31,7 @@ common_platform = [
] ]
def gen_handle_cancel(matcher: Type[Matcher], message: str): def gen_handle_cancel(matcher: type[Matcher], message: str):
async def _handle_cancel(text: Annotated[str, EventPlainText()]): async def _handle_cancel(text: Annotated[str, EventPlainText()]):
if text == "取消": if text == "取消":
await matcher.finish(message) await matcher.finish(message)
@ -39,12 +39,10 @@ def gen_handle_cancel(matcher: Type[Matcher], message: str):
return Depends(_handle_cancel) return Depends(_handle_cancel)
def ensure_user_info(matcher: Type[Matcher]): def ensure_user_info(matcher: type[Matcher]):
async def _check_user_info(state: T_State): async def _check_user_info(state: T_State):
if not state.get("target_user_info"): if not state.get("target_user_info"):
await matcher.finish( await matcher.finish("No target_user_info set, this shouldn't happen, please issue")
"No target_user_info set, this shouldn't happen, please issue"
)
return _check_user_info return _check_user_info

View File

@ -1,17 +1,16 @@
import difflib import difflib
import re import re
import sys import sys
from typing import Union
import nonebot import nonebot
from bs4 import BeautifulSoup as bs
from nonebot.log import default_format, logger
from nonebot.plugin import require from nonebot.plugin import require
from nonebot_plugin_saa import Image, MessageSegmentFactory, Text from bs4 import BeautifulSoup as bs
from nonebot.log import logger, default_format
from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
from ..plugin_config import plugin_config
from .context import ProcessContext
from .http import http_client from .http import http_client
from .context import ProcessContext
from ..plugin_config import plugin_config
from .scheduler_config import SchedulerConfig, scheduler from .scheduler_config import SchedulerConfig, scheduler
__all__ = [ __all__ = [
@ -30,7 +29,7 @@ class Singleton(type):
def __call__(cls, *args, **kwargs): def __call__(cls, *args, **kwargs):
if cls not in cls._instances: if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls] return cls._instances[cls]
@ -63,7 +62,7 @@ def html_to_text(html: str, query_dict: dict = {}) -> str:
class Filter: class Filter:
def __init__(self) -> None: def __init__(self) -> None:
self.level: Union[int, str] = "DEBUG" self.level: int | str = "DEBUG"
def __call__(self, record): def __call__(self, record):
module_name: str = record["name"] module_name: str = record["name"]
@ -71,9 +70,7 @@ class Filter:
if module: if module:
module_name = getattr(module, "__module_name__", module_name) module_name = getattr(module, "__module_name__", module_name)
record["name"] = module_name.split(".")[0] record["name"] = module_name.split(".")[0]
levelno = ( levelno = logger.level(self.level).no if isinstance(self.level, str) else self.level
logger.level(self.level).no if isinstance(self.level, str) else self.level
)
nonebot_warning_level = logger.level("WARNING").no nonebot_warning_level = logger.level("WARNING").no
return ( return (
record["level"].no >= levelno record["level"].no >= levelno
@ -94,11 +91,7 @@ if plugin_config.bison_filter_log:
) )
config = nonebot.get_driver().config config = nonebot.get_driver().config
logger.success("Muted info & success from nonebot") logger.success("Muted info & success from nonebot")
default_filter.level = ( default_filter.level = ("DEBUG" if config.debug else "INFO") if config.log_level is None else config.log_level
("DEBUG" if config.debug else "INFO")
if config.log_level is None
else config.log_level
)
def text_similarity(str1, str2) -> float: def text_similarity(str1, str2) -> float:

View File

@ -1,6 +1,6 @@
from base64 import b64encode from base64 import b64encode
from httpx import AsyncClient, Response from httpx import Response, AsyncClient
class ProcessContext: class ProcessContext:
@ -33,8 +33,13 @@ class ProcessContext:
res = [] res = []
for req in self.reqs: for req in self.reqs:
if self._should_print_content(req): if self._should_print_content(req):
log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} {req.text}" log_content = (
f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} {req.text}"
)
else: else:
log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} b64encoded: {b64encode(req.content[:50]).decode()}" log_content = (
f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} "
f"b64encoded: {b64encode(req.content[:50]).decode()}"
)
res.append(log_content) res.append(log_content)
return res return res

View File

@ -1,5 +1,3 @@
import functools
import httpx import httpx
from ..plugin_config import plugin_config from ..plugin_config import plugin_config
@ -10,7 +8,6 @@ http_args = {
http_headers = {"user-agent": plugin_config.bison_ua} http_headers = {"user-agent": plugin_config.bison_ua}
@functools.wraps(httpx.AsyncClient)
def http_client(*args, **kwargs): def http_client(*args, **kwargs):
if headers := kwargs.get("headers"): if headers := kwargs.get("headers"):
new_headers = http_headers.copy() new_headers = http_headers.copy()
@ -18,7 +15,4 @@ def http_client(*args, **kwargs):
kwargs["headers"] = new_headers kwargs["headers"] = new_headers
else: else:
kwargs["headers"] = http_headers kwargs["headers"] = http_headers
return httpx.AsyncClient(*args, **kwargs) return httpx.AsyncClient(*args, **kwargs, **http_args)
http_client = functools.partial(http_client, **http_args)

View File

@ -1,4 +1,4 @@
from typing import Literal, Type from typing import Literal
from httpx import AsyncClient from httpx import AsyncClient
@ -7,7 +7,6 @@ from .http import http_client
class SchedulerConfig: class SchedulerConfig:
schedule_type: Literal["date", "interval", "cron"] schedule_type: Literal["date", "interval", "cron"]
schedule_setting: dict schedule_setting: dict
name: str name: str
@ -25,9 +24,7 @@ class SchedulerConfig:
return self.default_http_client return self.default_http_client
def scheduler( def scheduler(schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict) -> type[SchedulerConfig]:
schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict
) -> Type[SchedulerConfig]:
return type( return type(
"AnonymousScheduleConfig", "AnonymousScheduleConfig",
(SchedulerConfig,), (SchedulerConfig,),

View File

@ -85,7 +85,7 @@ line-length = 120
target-version = "py310" target-version = "py310"
[tool.black] [tool.black]
line-length = 120 line-length = 118
target-version = ["py310", "py311"] target-version = ["py310", "py311"]
include = '\.pyi?$' include = '\.pyi?$'
extend-exclude = ''' extend-exclude = '''