mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-04 02:26:11 +08:00
🎨 按ruff的检查调整程序代码
This commit is contained in:
parent
f232ce4c3e
commit
dba8f2a9cb
@ -6,25 +6,18 @@ require("nonebot_plugin_saa")
|
||||
|
||||
import nonebot_plugin_saa
|
||||
|
||||
from . import (
|
||||
admin_page,
|
||||
bootstrap,
|
||||
config,
|
||||
platform,
|
||||
post,
|
||||
scheduler,
|
||||
send,
|
||||
sub_manager,
|
||||
types,
|
||||
utils,
|
||||
)
|
||||
from .plugin_config import PlugConfig, plugin_config
|
||||
from . import post, send, types, utils, config, platform, bootstrap, scheduler, admin_page, sub_manager
|
||||
|
||||
__help__version__ = "0.7.3"
|
||||
nonebot_plugin_saa.enable_auto_select_bot()
|
||||
|
||||
__help__plugin__name__ = "nonebot_bison"
|
||||
__usage__ = f"本bot可以提供b站、微博等社交媒体的消息订阅,详情请查看本bot文档,或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,发送“查询订阅”或“删除订阅”管理订阅"
|
||||
__usage__ = (
|
||||
"本bot可以提供b站、微博等社交媒体的消息订阅,详情请查看本bot文档,"
|
||||
f"或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,"
|
||||
f"发送“查询订阅”或“删除订阅”管理订阅"
|
||||
)
|
||||
|
||||
__supported_adapters__ = nonebot_plugin_saa.__plugin_meta__.supported_adapters
|
||||
|
||||
@ -41,6 +34,7 @@ __plugin_meta__ = PluginMetadata(
|
||||
|
||||
__all__ = [
|
||||
"admin_page",
|
||||
"bootstrap",
|
||||
"config",
|
||||
"sub_manager",
|
||||
"post",
|
||||
|
@ -1,17 +1,16 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from nonebot import get_driver, on_command
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
from nonebot.adapters.onebot.v11.event import PrivateMessageEvent
|
||||
from nonebot.drivers.fastapi import Driver
|
||||
from nonebot.log import logger
|
||||
from nonebot.rule import to_me
|
||||
from nonebot.typing import T_State
|
||||
from nonebot import get_driver, on_command
|
||||
from nonebot.drivers.fastapi import Driver
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
from nonebot.adapters.onebot.v11.event import PrivateMessageEvent
|
||||
|
||||
from ..plugin_config import plugin_config
|
||||
from .api import router as api_router
|
||||
from ..plugin_config import plugin_config
|
||||
from .token_manager import token_manager as tm
|
||||
|
||||
STATIC_PATH = (Path(__file__).parent / "dist").resolve()
|
||||
@ -28,11 +27,9 @@ def init_fastapi():
|
||||
class SinglePageApplication(StaticFiles):
|
||||
def __init__(self, directory: os.PathLike, index="index.html"):
|
||||
self.index = index
|
||||
super().__init__(
|
||||
directory=directory, packages=None, html=True, check_dir=True
|
||||
)
|
||||
super().__init__(directory=directory, packages=None, html=True, check_dir=True)
|
||||
|
||||
def lookup_path(self, path: str) -> tuple[str, Union[os.stat_result, None]]:
|
||||
def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
|
||||
full_path, stat_res = super().lookup_path(path)
|
||||
if stat_res is None:
|
||||
return super().lookup_path(self.index)
|
||||
@ -45,9 +42,7 @@ def init_fastapi():
|
||||
description="nonebot-bison webui and api",
|
||||
)
|
||||
nonebot_app.include_router(api_router)
|
||||
nonebot_app.mount(
|
||||
"/", SinglePageApplication(directory=static_path), name="bison-frontend"
|
||||
)
|
||||
nonebot_app.mount("/", SinglePageApplication(directory=static_path), name="bison-frontend")
|
||||
|
||||
app = driver.server_app
|
||||
app.mount("/bison", nonebot_app, "nonebot-bison")
|
||||
@ -63,10 +58,9 @@ def init_fastapi():
|
||||
if host in ["0.0.0.0", "127.0.0.1"]:
|
||||
host = "localhost"
|
||||
logger.opt(colors=True).info(
|
||||
f"Nonebot Bison frontend will be running at: "
|
||||
f"<b><u>http://{host}:{port}/bison</u></b>"
|
||||
f"Nonebot Bison frontend will be running at: " f"<b><u>http://{host}:{port}/bison</u></b>"
|
||||
)
|
||||
logger.opt(colors=True).info(f"该页面不能被直接访问,请私聊bot <b><u>后台管理</u></b> 以获取可访问地址")
|
||||
logger.opt(colors=True).info("该页面不能被直接访问,请私聊bot <b><u>后台管理</u></b> 以获取可访问地址")
|
||||
|
||||
|
||||
def register_get_token_handler():
|
||||
@ -93,6 +87,4 @@ if (STATIC_PATH / "index.html").exists():
|
||||
else:
|
||||
logger.warning("your driver is not fastapi, webui feature will be disabled")
|
||||
else:
|
||||
logger.warning(
|
||||
"Frontend file not found, please compile it or use docker or pypi version"
|
||||
)
|
||||
logger.warning("Frontend file not found, please compile it or use docker or pypi version")
|
||||
|
@ -1,35 +1,30 @@
|
||||
import nonebot
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.param_functions import Depends
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.security.oauth2 import OAuth2PasswordBearer
|
||||
from fastapi.param_functions import Depends
|
||||
from fastapi.exceptions import HTTPException
|
||||
from nonebot_plugin_saa import TargetQQGroup
|
||||
from fastapi.security.oauth2 import OAuth2PasswordBearer
|
||||
from nonebot_plugin_saa.utils.auto_select_bot import get_bot
|
||||
|
||||
from ..apis import check_sub_target
|
||||
from ..config import (
|
||||
NoSuchSubscribeException,
|
||||
NoSuchTargetException,
|
||||
NoSuchUserException,
|
||||
config,
|
||||
)
|
||||
from ..config.db_config import SubscribeDupException
|
||||
from ..platform import platform_manager
|
||||
from ..types import Target as T_Target
|
||||
from ..types import WeightConfig
|
||||
from ..utils.get_bot import get_groups
|
||||
from ..apis import check_sub_target
|
||||
from .jwt import load_jwt, pack_jwt
|
||||
from ..types import Target as T_Target
|
||||
from ..utils.get_bot import get_groups
|
||||
from ..platform import platform_manager
|
||||
from .token_manager import token_manager
|
||||
from ..config.db_config import SubscribeDupException
|
||||
from ..config import NoSuchUserException, NoSuchTargetException, NoSuchSubscribeException, config
|
||||
from .types import (
|
||||
AddSubscribeReq,
|
||||
TokenResp,
|
||||
GlobalConf,
|
||||
PlatformConfig,
|
||||
StatusResp,
|
||||
SubscribeResp,
|
||||
PlatformConfig,
|
||||
AddSubscribeReq,
|
||||
SubscribeConfig,
|
||||
SubscribeGroupDetail,
|
||||
SubscribeResp,
|
||||
TokenResp,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["api"])
|
||||
@ -44,9 +39,7 @@ async def get_jwt_obj(token: str = Depends(oath_scheme)):
|
||||
return obj
|
||||
|
||||
|
||||
async def check_group_permission(
|
||||
groupNumber: int, token_obj: dict = Depends(get_jwt_obj)
|
||||
):
|
||||
async def check_group_permission(groupNumber: int, token_obj: dict = Depends(get_jwt_obj)):
|
||||
groups = token_obj["groups"]
|
||||
for group in groups:
|
||||
if int(groupNumber) == group["id"]:
|
||||
@ -95,15 +88,13 @@ async def auth(token: str) -> TokenResp:
|
||||
jwt_obj = {
|
||||
"id": qq,
|
||||
"type": "admin",
|
||||
"groups": list(
|
||||
map(
|
||||
lambda info: {
|
||||
"id": info["group_id"],
|
||||
"name": info["group_name"],
|
||||
},
|
||||
await get_groups(),
|
||||
)
|
||||
),
|
||||
"groups": [
|
||||
{
|
||||
"id": info["group_id"],
|
||||
"name": info["group_name"],
|
||||
}
|
||||
for info in await get_groups()
|
||||
],
|
||||
}
|
||||
ret_obj = TokenResp(
|
||||
type="admin",
|
||||
@ -134,18 +125,16 @@ async def get_subs_info(jwt_obj: dict = Depends(get_jwt_obj)) -> SubscribeResp:
|
||||
for group in groups:
|
||||
group_id = group["id"]
|
||||
raw_subs = await config.list_subscribe(TargetQQGroup(group_id=group_id))
|
||||
subs = list(
|
||||
map(
|
||||
lambda sub: SubscribeConfig(
|
||||
platformName=sub.target.platform_name,
|
||||
targetName=sub.target.target_name,
|
||||
cats=sub.categories,
|
||||
tags=sub.tags,
|
||||
target=sub.target.target,
|
||||
),
|
||||
raw_subs,
|
||||
subs = [
|
||||
SubscribeConfig(
|
||||
platformName=sub.target.platform_name,
|
||||
targetName=sub.target.target_name,
|
||||
cats=sub.categories,
|
||||
tags=sub.tags,
|
||||
target=sub.target.target,
|
||||
)
|
||||
)
|
||||
for sub in raw_subs
|
||||
]
|
||||
res[group_id] = SubscribeGroupDetail(name=group["name"], subscribes=subs)
|
||||
return res
|
||||
|
||||
@ -174,9 +163,7 @@ async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp:
|
||||
@router.delete("/subs", dependencies=[Depends(check_group_permission)])
|
||||
async def del_group_sub(groupNumber: int, platformName: str, target: str):
|
||||
try:
|
||||
await config.del_subscribe(
|
||||
TargetQQGroup(group_id=groupNumber), target, platformName
|
||||
)
|
||||
await config.del_subscribe(TargetQQGroup(group_id=groupNumber), target, platformName)
|
||||
except (NoSuchUserException, NoSuchSubscribeException):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe")
|
||||
return StatusResp(ok=True, msg="")
|
||||
@ -204,13 +191,9 @@ async def get_weight_config():
|
||||
|
||||
|
||||
@router.put("/weight", dependencies=[Depends(check_is_superuser)])
|
||||
async def update_weigth_config(
|
||||
platformName: str, target: str, weight_config: WeightConfig
|
||||
):
|
||||
async def update_weigth_config(platformName: str, target: str, weight_config: WeightConfig):
|
||||
try:
|
||||
await config.update_time_weight_config(
|
||||
T_Target(target), platformName, weight_config
|
||||
)
|
||||
await config.update_time_weight_config(T_Target(target), platformName, weight_config)
|
||||
except NoSuchTargetException:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such subscribe")
|
||||
return StatusResp(ok=True, msg="")
|
||||
|
@ -1,7 +1,6 @@
|
||||
import datetime
|
||||
import random
|
||||
import string
|
||||
from typing import Optional
|
||||
import datetime
|
||||
|
||||
import jwt
|
||||
|
||||
@ -16,8 +15,8 @@ def pack_jwt(obj: dict) -> str:
|
||||
)
|
||||
|
||||
|
||||
def load_jwt(token: str) -> Optional[dict]:
|
||||
def load_jwt(token: str) -> dict | None:
|
||||
try:
|
||||
return jwt.decode(token, _key, algorithms=["HS256"])
|
||||
except:
|
||||
except Exception:
|
||||
return None
|
||||
|
@ -1,6 +1,5 @@
|
||||
import random
|
||||
import string
|
||||
from typing import Optional
|
||||
|
||||
from expiringdict import ExpiringDict
|
||||
|
||||
@ -9,7 +8,7 @@ class TokenManager:
|
||||
def __init__(self):
|
||||
self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60 * 10)
|
||||
|
||||
def get_user(self, token: str) -> Optional[tuple]:
|
||||
def get_user(self, token: str) -> tuple | None:
|
||||
res = self.token_manager.get(token)
|
||||
assert res is None or isinstance(res, tuple)
|
||||
return res
|
||||
|
@ -1,2 +1,4 @@
|
||||
from .db_config import config
|
||||
from .utils import NoSuchSubscribeException, NoSuchTargetException, NoSuchUserException
|
||||
from .db_config import config as config
|
||||
from .utils import NoSuchUserException as NoSuchUserException
|
||||
from .utils import NoSuchTargetException as NoSuchTargetException
|
||||
from .utils import NoSuchSubscribeException as NoSuchSubscribeException
|
||||
|
@ -1,20 +1,19 @@
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
import json
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import DefaultDict, Literal, Mapping, TypedDict
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
import nonebot
|
||||
from nonebot.log import logger
|
||||
from tinydb import Query, TinyDB
|
||||
|
||||
from ..utils import Singleton
|
||||
from ..types import User, Target
|
||||
from ..platform import platform_manager
|
||||
from ..plugin_config import plugin_config
|
||||
from ..types import Target, User
|
||||
from ..utils import Singleton
|
||||
from .utils import NoSuchSubscribeException, NoSuchUserException
|
||||
from .utils import NoSuchUserException, NoSuchSubscribeException
|
||||
|
||||
supported_target_type = platform_manager.keys()
|
||||
|
||||
@ -89,17 +88,16 @@ class Config(metaclass=Singleton):
|
||||
self.target_user_cat_cache = {}
|
||||
self.target_user_tag_cache = {}
|
||||
self.target_list = {}
|
||||
self.next_index: DefaultDict[str, int] = defaultdict(lambda: 0)
|
||||
self.next_index: defaultdict[str, int] = defaultdict(lambda: 0)
|
||||
else:
|
||||
self.available = False
|
||||
|
||||
def add_subscribe(
|
||||
self, user, user_type, target, target_name, target_type, cats, tags
|
||||
):
|
||||
def add_subscribe(self, user, user_type, target, target_name, target_type, cats, tags):
|
||||
user_query = Query()
|
||||
query = (user_query.user == user) & (user_query.user_type == user_type)
|
||||
if user_data := self.user_target.get(query):
|
||||
# update
|
||||
assert not isinstance(user_data, list)
|
||||
subs: list = user_data.get("subs", [])
|
||||
subs.append(
|
||||
{
|
||||
@ -132,9 +130,8 @@ class Config(metaclass=Singleton):
|
||||
|
||||
def list_subscribe(self, user, user_type) -> list[SubscribeContent]:
|
||||
query = Query()
|
||||
if user_sub := self.user_target.get(
|
||||
(query.user == user) & (query.user_type == user_type)
|
||||
):
|
||||
if user_sub := self.user_target.get((query.user == user) & (query.user_type == user_type)):
|
||||
assert not isinstance(user_sub, list)
|
||||
return user_sub["subs"]
|
||||
return []
|
||||
|
||||
@ -146,6 +143,7 @@ class Config(metaclass=Singleton):
|
||||
query = (user_query.user == user) & (user_query.user_type == user_type)
|
||||
if not (query_res := self.user_target.get(query)):
|
||||
raise NoSuchUserException()
|
||||
assert not isinstance(query_res, list)
|
||||
subs = query_res.get("subs", [])
|
||||
for idx, sub in enumerate(subs):
|
||||
if sub.get("target") == target and sub.get("target_type") == target_type:
|
||||
@ -155,13 +153,12 @@ class Config(metaclass=Singleton):
|
||||
return
|
||||
raise NoSuchSubscribeException()
|
||||
|
||||
def update_subscribe(
|
||||
self, user, user_type, target, target_name, target_type, cats, tags
|
||||
):
|
||||
def update_subscribe(self, user, user_type, target, target_name, target_type, cats, tags):
|
||||
user_query = Query()
|
||||
query = (user_query.user == user) & (user_query.user_type == user_type)
|
||||
if user_data := self.user_target.get(query):
|
||||
# update
|
||||
assert not isinstance(user_data, list)
|
||||
subs: list = user_data.get("subs", [])
|
||||
find_flag = False
|
||||
for item in subs:
|
||||
@ -182,19 +179,13 @@ class Config(metaclass=Singleton):
|
||||
|
||||
def update_send_cache(self):
|
||||
res = {target_type: defaultdict(list) for target_type in supported_target_type}
|
||||
cat_res = {
|
||||
target_type: defaultdict(lambda: defaultdict(list))
|
||||
for target_type in supported_target_type
|
||||
}
|
||||
tag_res = {
|
||||
target_type: defaultdict(lambda: defaultdict(list))
|
||||
for target_type in supported_target_type
|
||||
}
|
||||
cat_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
|
||||
tag_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
|
||||
# res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
|
||||
to_del = []
|
||||
for user in self.user_target.all():
|
||||
for sub in user.get("subs", []):
|
||||
if not sub.get("target_type") in supported_target_type:
|
||||
if sub.get("target_type") not in supported_target_type:
|
||||
to_del.append(
|
||||
{
|
||||
"user": user["user"],
|
||||
@ -204,36 +195,28 @@ class Config(metaclass=Singleton):
|
||||
}
|
||||
)
|
||||
continue
|
||||
res[sub["target_type"]][sub["target"]].append(
|
||||
User(user["user"], user["user_type"])
|
||||
)
|
||||
cat_res[sub["target_type"]][sub["target"]][
|
||||
"{}-{}".format(user["user_type"], user["user"])
|
||||
] = sub["cats"]
|
||||
tag_res[sub["target_type"]][sub["target"]][
|
||||
"{}-{}".format(user["user_type"], user["user"])
|
||||
] = sub["tags"]
|
||||
res[sub["target_type"]][sub["target"]].append(User(user["user"], user["user_type"]))
|
||||
cat_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[
|
||||
"cats"
|
||||
]
|
||||
tag_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[
|
||||
"tags"
|
||||
]
|
||||
self.target_user_cache = res
|
||||
self.target_user_cat_cache = cat_res
|
||||
self.target_user_tag_cache = tag_res
|
||||
for target_type in self.target_user_cache:
|
||||
self.target_list[target_type] = list(
|
||||
self.target_user_cache[target_type].keys()
|
||||
)
|
||||
self.target_list[target_type] = list(self.target_user_cache[target_type].keys())
|
||||
|
||||
logger.info(f"Deleting {to_del}")
|
||||
for d in to_del:
|
||||
self.del_subscribe(**d)
|
||||
|
||||
def get_sub_category(self, target_type, target, user_type, user):
|
||||
return self.target_user_cat_cache[target_type][target][
|
||||
"{}-{}".format(user_type, user)
|
||||
]
|
||||
return self.target_user_cat_cache[target_type][target][f"{user_type}-{user}"]
|
||||
|
||||
def get_sub_tags(self, target_type, target, user_type, user):
|
||||
return self.target_user_tag_cache[target_type][target][
|
||||
"{}-{}".format(user_type, user)
|
||||
]
|
||||
return self.target_user_tag_cache[target_type][target][f"{user_type}-{user}"]
|
||||
|
||||
def get_next_target(self, target_type):
|
||||
# FIXME 插入或删除target后对队列的影响(但是并不是大问题
|
||||
|
@ -1,19 +1,19 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, time
|
||||
from typing import Awaitable, Callable, Optional, Sequence
|
||||
from datetime import time, datetime
|
||||
from collections.abc import Callable, Sequence, Awaitable
|
||||
|
||||
from nonebot_plugin_datastore import create_session
|
||||
from nonebot_plugin_saa import PlatformTarget
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy import func, delete, select
|
||||
from nonebot_plugin_saa import PlatformTarget
|
||||
from nonebot_plugin_datastore import create_session
|
||||
|
||||
from ..types import Category, PlatformWeightConfigResp, Tag
|
||||
from ..types import Tag
|
||||
from ..types import Target as T_Target
|
||||
from ..types import TimeWeightConfig, UserSubInfo, WeightConfig
|
||||
from .db_model import ScheduleTimeWeight, Subscribe, Target, User
|
||||
from .utils import NoSuchTargetException
|
||||
from .db_model import User, Target, Subscribe, ScheduleTimeWeight
|
||||
from ..types import Category, UserSubInfo, WeightConfig, TimeWeightConfig, PlatformWeightConfigResp
|
||||
|
||||
|
||||
def _get_time():
|
||||
@ -48,23 +48,17 @@ class DBConfig:
|
||||
):
|
||||
async with create_session() as session:
|
||||
db_user_stmt = select(User).where(User.user_target == user.dict())
|
||||
db_user: Optional[User] = await session.scalar(db_user_stmt)
|
||||
db_user: User | None = await session.scalar(db_user_stmt)
|
||||
if not db_user:
|
||||
db_user = User(user_target=user.dict())
|
||||
session.add(db_user)
|
||||
db_target_stmt = (
|
||||
select(Target)
|
||||
.where(Target.platform_name == platform_name)
|
||||
.where(Target.target == target)
|
||||
select(Target).where(Target.platform_name == platform_name).where(Target.target == target)
|
||||
)
|
||||
db_target: Optional[Target] = await session.scalar(db_target_stmt)
|
||||
db_target: Target | None = await session.scalar(db_target_stmt)
|
||||
if not db_target:
|
||||
db_target = Target(
|
||||
target=target, platform_name=platform_name, target_name=target_name
|
||||
)
|
||||
await asyncio.gather(
|
||||
*[hook(platform_name, target) for hook in self.add_target_hook]
|
||||
)
|
||||
db_target = Target(target=target, platform_name=platform_name, target_name=target_name)
|
||||
await asyncio.gather(*[hook(platform_name, target) for hook in self.add_target_hook])
|
||||
else:
|
||||
db_target.target_name = target_name
|
||||
subscribe = Subscribe(
|
||||
@ -96,44 +90,25 @@ class DBConfig:
|
||||
"""获取数据库中带有user、target信息的subscribe数据"""
|
||||
async with create_session() as session:
|
||||
query_stmt = (
|
||||
select(Subscribe)
|
||||
.join(User)
|
||||
.options(selectinload(Subscribe.target), selectinload(Subscribe.user))
|
||||
select(Subscribe).join(User).options(selectinload(Subscribe.target), selectinload(Subscribe.user))
|
||||
)
|
||||
subs = (await session.scalars(query_stmt)).all()
|
||||
|
||||
return subs
|
||||
|
||||
async def del_subscribe(
|
||||
self, user: PlatformTarget, target: str, platform_name: str
|
||||
):
|
||||
async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str):
|
||||
async with create_session() as session:
|
||||
user_obj = await session.scalar(
|
||||
select(User).where(User.user_target == user.dict())
|
||||
)
|
||||
user_obj = await session.scalar(select(User).where(User.user_target == user.dict()))
|
||||
target_obj = await session.scalar(
|
||||
select(Target).where(
|
||||
Target.platform_name == platform_name, Target.target == target
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
delete(Subscribe).where(
|
||||
Subscribe.user == user_obj, Subscribe.target == target_obj
|
||||
)
|
||||
select(Target).where(Target.platform_name == platform_name, Target.target == target)
|
||||
)
|
||||
await session.execute(delete(Subscribe).where(Subscribe.user == user_obj, Subscribe.target == target_obj))
|
||||
target_count = await session.scalar(
|
||||
select(func.count())
|
||||
.select_from(Subscribe)
|
||||
.where(Subscribe.target == target_obj)
|
||||
select(func.count()).select_from(Subscribe).where(Subscribe.target == target_obj)
|
||||
)
|
||||
if target_count == 0:
|
||||
# delete empty target
|
||||
await asyncio.gather(
|
||||
*[
|
||||
hook(platform_name, T_Target(target))
|
||||
for hook in self.delete_target_hook
|
||||
]
|
||||
)
|
||||
await asyncio.gather(*[hook(platform_name, T_Target(target)) for hook in self.delete_target_hook])
|
||||
await session.commit()
|
||||
|
||||
async def update_subscribe(
|
||||
@ -165,29 +140,22 @@ class DBConfig:
|
||||
async def get_platform_target(self, platform_name: str) -> Sequence[Target]:
|
||||
async with create_session() as sess:
|
||||
subq = select(Subscribe.target_id).distinct().subquery()
|
||||
query = (
|
||||
select(Target).join(subq).where(Target.platform_name == platform_name)
|
||||
)
|
||||
query = select(Target).join(subq).where(Target.platform_name == platform_name)
|
||||
return (await sess.scalars(query)).all()
|
||||
|
||||
async def get_time_weight_config(
|
||||
self, target: T_Target, platform_name: str
|
||||
) -> WeightConfig:
|
||||
async def get_time_weight_config(self, target: T_Target, platform_name: str) -> WeightConfig:
|
||||
async with create_session() as sess:
|
||||
time_weight_conf = (
|
||||
await sess.scalars(
|
||||
select(ScheduleTimeWeight)
|
||||
.where(
|
||||
Target.platform_name == platform_name, Target.target == target
|
||||
)
|
||||
.where(Target.platform_name == platform_name, Target.target == target)
|
||||
.join(Target)
|
||||
)
|
||||
).all()
|
||||
targetObj = await sess.scalar(
|
||||
select(Target).where(
|
||||
Target.platform_name == platform_name, Target.target == target
|
||||
)
|
||||
select(Target).where(Target.platform_name == platform_name, Target.target == target)
|
||||
)
|
||||
assert targetObj
|
||||
return WeightConfig(
|
||||
default=targetObj.default_schedule_weight,
|
||||
time_config=[
|
||||
@ -200,22 +168,16 @@ class DBConfig:
|
||||
],
|
||||
)
|
||||
|
||||
async def update_time_weight_config(
|
||||
self, target: T_Target, platform_name: str, conf: WeightConfig
|
||||
):
|
||||
async def update_time_weight_config(self, target: T_Target, platform_name: str, conf: WeightConfig):
|
||||
async with create_session() as sess:
|
||||
targetObj = await sess.scalar(
|
||||
select(Target).where(
|
||||
Target.platform_name == platform_name, Target.target == target
|
||||
)
|
||||
select(Target).where(Target.platform_name == platform_name, Target.target == target)
|
||||
)
|
||||
if not targetObj:
|
||||
raise NoSuchTargetException()
|
||||
target_id = targetObj.id
|
||||
targetObj.default_schedule_weight = conf.default
|
||||
delete_statement = delete(ScheduleTimeWeight).where(
|
||||
ScheduleTimeWeight.target_id == target_id
|
||||
)
|
||||
delete_statement = delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id)
|
||||
await sess.execute(delete_statement)
|
||||
for time_conf in conf.time_config:
|
||||
new_conf = ScheduleTimeWeight(
|
||||
@ -243,18 +205,13 @@ class DBConfig:
|
||||
key = f"{target.platform_name}-{target.target}"
|
||||
weight = target.default_schedule_weight
|
||||
for time_conf in target.time_weight:
|
||||
if (
|
||||
time_conf.start_time <= cur_time
|
||||
and time_conf.end_time > cur_time
|
||||
):
|
||||
if time_conf.start_time <= cur_time and time_conf.end_time > cur_time:
|
||||
weight = time_conf.weight
|
||||
break
|
||||
res[key] = weight
|
||||
return res
|
||||
|
||||
async def get_platform_target_subscribers(
|
||||
self, platform_name: str, target: T_Target
|
||||
) -> list[UserSubInfo]:
|
||||
async def get_platform_target_subscribers(self, platform_name: str, target: T_Target) -> list[UserSubInfo]:
|
||||
async with create_session() as sess:
|
||||
query = (
|
||||
select(Subscribe)
|
||||
@ -263,16 +220,14 @@ class DBConfig:
|
||||
.options(selectinload(Subscribe.user))
|
||||
)
|
||||
subsribes = (await sess.scalars(query)).all()
|
||||
return list(
|
||||
map(
|
||||
lambda subscribe: UserSubInfo(
|
||||
PlatformTarget.deserialize(subscribe.user.user_target),
|
||||
subscribe.categories,
|
||||
subscribe.tags,
|
||||
),
|
||||
subsribes,
|
||||
return [
|
||||
UserSubInfo(
|
||||
PlatformTarget.deserialize(subscribe.user.user_target),
|
||||
subscribe.categories,
|
||||
subscribe.tags,
|
||||
)
|
||||
)
|
||||
for subscribe in subsribes
|
||||
]
|
||||
|
||||
async def get_all_weight_config(
|
||||
self,
|
||||
@ -281,9 +236,7 @@ class DBConfig:
|
||||
async with create_session() as sess:
|
||||
query = select(Target)
|
||||
targets = (await sess.scalars(query)).all()
|
||||
query = select(ScheduleTimeWeight).options(
|
||||
selectinload(ScheduleTimeWeight.target)
|
||||
)
|
||||
query = select(ScheduleTimeWeight).options(selectinload(ScheduleTimeWeight.target))
|
||||
time_weights = (await sess.scalars(query)).all()
|
||||
|
||||
for target in targets:
|
||||
@ -293,9 +246,7 @@ class DBConfig:
|
||||
target=T_Target(target.target),
|
||||
target_name=target.target_name,
|
||||
platform_name=platform_name,
|
||||
weight=WeightConfig(
|
||||
default=target.default_schedule_weight, time_config=[]
|
||||
),
|
||||
weight=WeightConfig(default=target.default_schedule_weight, time_config=[]),
|
||||
)
|
||||
|
||||
for time_weight_config in time_weights:
|
||||
|
@ -1,22 +1,17 @@
|
||||
from nonebot.log import logger
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
|
||||
|
||||
from .db_model import User, Target, Subscribe
|
||||
from .config_legacy import Config, ConfigContent, drop
|
||||
from .db_model import Subscribe, Target, User
|
||||
|
||||
|
||||
async def data_migrate():
|
||||
config = Config()
|
||||
if config.available:
|
||||
logger.warning("You are still using legacy db, migrating to sqlite")
|
||||
all_subs: list[ConfigContent] = list(
|
||||
map(
|
||||
lambda item: ConfigContent(**item),
|
||||
config.get_all_subscribe().all(),
|
||||
)
|
||||
)
|
||||
all_subs: list[ConfigContent] = [ConfigContent(**item) for item in config.get_all_subscribe().all()]
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
user_to_create = []
|
||||
subscribe_to_create = []
|
||||
@ -37,8 +32,7 @@ async def data_migrate():
|
||||
if key in user_sub_set:
|
||||
# a user subscribe a target twice
|
||||
logger.error(
|
||||
f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次,"
|
||||
"随机采用了一个订阅"
|
||||
f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次,随机采用了一个订阅" # noqa: E501
|
||||
)
|
||||
continue
|
||||
user_sub_set.add(key)
|
||||
@ -69,11 +63,7 @@ async def data_migrate():
|
||||
tags=sub["tags"],
|
||||
)
|
||||
subscribe_to_create.append(subscribe_obj)
|
||||
sess.add_all(
|
||||
user_to_create
|
||||
+ list(map(lambda x: x[0], platform_target_map.values()))
|
||||
+ subscribe_to_create
|
||||
)
|
||||
sess.add_all(user_to_create + [x[0] for x in platform_target_map.values()] + subscribe_to_create)
|
||||
await sess.commit()
|
||||
drop()
|
||||
logger.info("migrate success")
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""init db
|
||||
|
||||
Revision ID: 0571870f5222
|
||||
Revises:
|
||||
Revises:
|
||||
Create Date: 2022-03-21 19:18:13.762626
|
||||
|
||||
"""
|
||||
|
@ -5,7 +5,6 @@ Revises: 5f3370328e44
|
||||
Create Date: 2023-01-15 19:04:54.987491
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
@ -5,7 +5,6 @@ Revises: 0571870f5222
|
||||
Create Date: 2022-03-26 19:46:50.910721
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@ -18,14 +17,10 @@ depends_on = None
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("subscribe", schema=None) as batch_op:
|
||||
batch_op.create_unique_constraint(
|
||||
"unique-subscribe-constraint", ["target_id", "user_id"]
|
||||
)
|
||||
batch_op.create_unique_constraint("unique-subscribe-constraint", ["target_id", "user_id"])
|
||||
|
||||
with op.batch_alter_table("target", schema=None) as batch_op:
|
||||
batch_op.create_unique_constraint(
|
||||
"unique-target-constraint", ["target", "platform_name"]
|
||||
)
|
||||
batch_op.create_unique_constraint("unique-target-constraint", ["target", "platform_name"])
|
||||
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"])
|
||||
|
@ -1,15 +1,14 @@
|
||||
from abc import ABC
|
||||
|
||||
from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo
|
||||
from pydantic import BaseModel
|
||||
from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo
|
||||
|
||||
from ....types import Category, Tag
|
||||
from ....types import Tag, Category
|
||||
|
||||
|
||||
class NBESFBase(BaseModel, ABC):
|
||||
|
||||
version: int # 表示nbesf格式版本,有效版本从1开始
|
||||
groups: list = list()
|
||||
groups: list = []
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
@ -1,25 +1,26 @@
|
||||
from typing import cast
|
||||
from collections import defaultdict
|
||||
from typing import Callable, cast
|
||||
from collections.abc import Callable
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot_plugin_datastore.db import create_session
|
||||
from nonebot_plugin_saa import PlatformTarget
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm.strategy_options import selectinload
|
||||
from nonebot.log import logger
|
||||
from sqlalchemy.sql.selectable import Select
|
||||
from nonebot_plugin_saa import PlatformTarget
|
||||
from nonebot_plugin_datastore.db import create_session
|
||||
from sqlalchemy.orm.strategy_options import selectinload
|
||||
|
||||
from ..db_model import Subscribe, User
|
||||
from .nbesf_model import NBESFBase, v1, v2
|
||||
from .utils import NBESFVerMatchErr
|
||||
from ..db_model import User, Subscribe
|
||||
from .nbesf_model import NBESFBase, v1, v2
|
||||
|
||||
|
||||
async def subscribes_export(selector: Callable[[Select], Select]) -> v2.SubGroup:
|
||||
|
||||
"""
|
||||
将Bison订阅导出为 Nonebot Bison Exchangable Subscribes File 标准格式的 SubGroup 类型数据
|
||||
|
||||
selector:
|
||||
对 sqlalchemy Select 对象的操作函数,用于限定查询范围 e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group")
|
||||
对 sqlalchemy Select 对象的操作函数,用于限定查询范围
|
||||
e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group")
|
||||
"""
|
||||
async with create_session() as sess:
|
||||
sub_stmt = select(Subscribe).join(User)
|
||||
|
@ -1,23 +1,22 @@
|
||||
from collections import defaultdict
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from pkgutil import iter_modules
|
||||
from typing import DefaultDict, Type
|
||||
from collections import defaultdict
|
||||
from importlib import import_module
|
||||
|
||||
from .platform import Platform, make_no_target_group
|
||||
|
||||
_package_dir = str(Path(__file__).resolve().parent)
|
||||
for (_, module_name, _) in iter_modules([_package_dir]):
|
||||
for _, module_name, _ in iter_modules([_package_dir]):
|
||||
import_module(f"{__name__}.{module_name}")
|
||||
|
||||
|
||||
_platform_list: DefaultDict[str, list[Type[Platform]]] = defaultdict(list)
|
||||
_platform_list: defaultdict[str, list[type[Platform]]] = defaultdict(list)
|
||||
for _platform in Platform.registry:
|
||||
if not _platform.enabled:
|
||||
continue
|
||||
_platform_list[_platform.platform_name].append(_platform)
|
||||
|
||||
platform_manager: dict[str, Type[Platform]] = dict()
|
||||
platform_manager: dict[str, type[Platform]] = {}
|
||||
for name, platform_list in _platform_list.items():
|
||||
if len(platform_list) == 1:
|
||||
platform_manager[name] = platform_list[0]
|
||||
|
@ -1,25 +1,23 @@
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from bs4 import BeautifulSoup as bs
|
||||
from httpx import AsyncClient
|
||||
from nonebot.plugin import require
|
||||
from bs4 import BeautifulSoup as bs
|
||||
|
||||
from ..post import Post
|
||||
from ..types import Category, RawPost, Target
|
||||
from ..types import Target, RawPost, Category
|
||||
from ..utils.scheduler_config import SchedulerConfig
|
||||
from .platform import CategoryNotRecognize, NewMessage, StatusChange
|
||||
from .platform import NewMessage, StatusChange, CategoryNotRecognize
|
||||
|
||||
|
||||
class ArknightsSchedConf(SchedulerConfig):
|
||||
|
||||
name = "arknights"
|
||||
schedule_type = "interval"
|
||||
schedule_setting = {"seconds": 30}
|
||||
|
||||
|
||||
class Arknights(NewMessage):
|
||||
|
||||
categories = {1: "游戏公告"}
|
||||
platform_name = "arknights"
|
||||
name = "明日方舟游戏信息"
|
||||
@ -30,9 +28,7 @@ class Arknights(NewMessage):
|
||||
has_target = False
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||
return "明日方舟游戏信息"
|
||||
|
||||
async def get_sub_list(self, _) -> list[RawPost]:
|
||||
@ -92,7 +88,6 @@ class Arknights(NewMessage):
|
||||
|
||||
|
||||
class AkVersion(StatusChange):
|
||||
|
||||
categories = {2: "更新信息"}
|
||||
platform_name = "arknights"
|
||||
name = "明日方舟游戏信息"
|
||||
@ -103,15 +98,11 @@ class AkVersion(StatusChange):
|
||||
has_target = False
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||
return "明日方舟游戏信息"
|
||||
|
||||
async def get_status(self, _):
|
||||
res_ver = await self.client.get(
|
||||
"https://ak-conf.hypergryph.com/config/prod/official/IOS/version"
|
||||
)
|
||||
res_ver = await self.client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version")
|
||||
res_preanounce = await self.client.get(
|
||||
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json"
|
||||
)
|
||||
@ -121,20 +112,10 @@ class AkVersion(StatusChange):
|
||||
|
||||
def compare_status(self, _, old_status, new_status):
|
||||
res = []
|
||||
if (
|
||||
old_status.get("preAnnounceType") == 2
|
||||
and new_status.get("preAnnounceType") == 0
|
||||
):
|
||||
res.append(
|
||||
Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息")
|
||||
)
|
||||
elif (
|
||||
old_status.get("preAnnounceType") == 0
|
||||
and new_status.get("preAnnounceType") == 2
|
||||
):
|
||||
res.append(
|
||||
Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息")
|
||||
)
|
||||
if old_status.get("preAnnounceType") == 2 and new_status.get("preAnnounceType") == 0:
|
||||
res.append(Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息")) # noqa: E501
|
||||
elif old_status.get("preAnnounceType") == 0 and new_status.get("preAnnounceType") == 2:
|
||||
res.append(Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息")) # noqa: E501
|
||||
if old_status.get("clientVersion") != new_status.get("clientVersion"):
|
||||
res.append(Post("arknights", text="游戏本体更新(大更新)", target_name="明日方舟更新信息"))
|
||||
if old_status.get("resVersion") != new_status.get("resVersion"):
|
||||
@ -149,7 +130,6 @@ class AkVersion(StatusChange):
|
||||
|
||||
|
||||
class MonsterSiren(NewMessage):
|
||||
|
||||
categories = {3: "塞壬唱片新闻"}
|
||||
platform_name = "arknights"
|
||||
name = "明日方舟游戏信息"
|
||||
@ -160,15 +140,11 @@ class MonsterSiren(NewMessage):
|
||||
has_target = False
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||
return "明日方舟游戏信息"
|
||||
|
||||
async def get_sub_list(self, _) -> list[RawPost]:
|
||||
raw_data = await self.client.get(
|
||||
"https://monster-siren.hypergryph.com/api/news"
|
||||
)
|
||||
raw_data = await self.client.get("https://monster-siren.hypergryph.com/api/news")
|
||||
return raw_data.json()["data"]["list"]
|
||||
|
||||
def get_id(self, post: RawPost) -> Any:
|
||||
@ -182,14 +158,12 @@ class MonsterSiren(NewMessage):
|
||||
|
||||
async def parse(self, raw_post: RawPost) -> Post:
|
||||
url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}'
|
||||
res = await self.client.get(
|
||||
f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}'
|
||||
)
|
||||
res = await self.client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}')
|
||||
raw_data = res.json()
|
||||
content = raw_data["data"]["content"]
|
||||
content = content.replace("</p>", "</p>\n")
|
||||
soup = bs(content, "html.parser")
|
||||
imgs = list(map(lambda x: x["src"], soup("img")))
|
||||
imgs = [x["src"] for x in soup("img")]
|
||||
text = f'{raw_post["title"]}\n{soup.text.strip()}'
|
||||
return Post(
|
||||
"monster-siren",
|
||||
@ -203,7 +177,6 @@ class MonsterSiren(NewMessage):
|
||||
|
||||
|
||||
class TerraHistoricusComic(NewMessage):
|
||||
|
||||
categories = {4: "泰拉记事社漫画"}
|
||||
platform_name = "arknights"
|
||||
name = "明日方舟游戏信息"
|
||||
@ -214,15 +187,11 @@ class TerraHistoricusComic(NewMessage):
|
||||
has_target = False
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||
return "明日方舟游戏信息"
|
||||
|
||||
async def get_sub_list(self, _) -> list[RawPost]:
|
||||
raw_data = await self.client.get(
|
||||
"https://terra-historicus.hypergryph.com/api/recentUpdate"
|
||||
)
|
||||
raw_data = await self.client.get("https://terra-historicus.hypergryph.com/api/recentUpdate")
|
||||
return raw_data.json()["data"]
|
||||
|
||||
def get_id(self, post: RawPost) -> Any:
|
||||
|
@ -1,14 +1,14 @@
|
||||
import json
|
||||
import re
|
||||
import json
|
||||
from typing import Any
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum, unique
|
||||
from typing import Any, Literal, Optional
|
||||
from typing_extensions import Self
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from httpx import AsyncClient
|
||||
from nonebot.log import logger
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from ..post import Post
|
||||
from ..types import ApiError, Category, RawPost, Tag, Target
|
||||
@ -25,9 +25,7 @@ class BilibiliSchedConf(SchedulerConfig):
|
||||
cookie_expire_time = timedelta(hours=5)
|
||||
|
||||
def __init__(self):
|
||||
self._client_refresh_time = datetime(
|
||||
year=2000, month=1, day=1
|
||||
) # an expired time
|
||||
self._client_refresh_time = datetime(year=2000, month=1, day=1) # an expired time
|
||||
super().__init__()
|
||||
|
||||
async def _init_session(self):
|
||||
@ -69,12 +67,8 @@ class Bilibili(NewMessage):
|
||||
parse_target_promot = "请输入用户主页的链接"
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
res = await client.get(
|
||||
"https://api.bilibili.com/x/web-interface/card", params={"mid": target}
|
||||
)
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||
res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target})
|
||||
res.raise_for_status()
|
||||
res_data = res.json()
|
||||
if res_data["code"]:
|
||||
@ -129,12 +123,7 @@ class Bilibili(NewMessage):
|
||||
return self._do_get_category(post_type)
|
||||
|
||||
def get_tags(self, raw_post: RawPost) -> list[Tag]:
|
||||
return [
|
||||
*map(
|
||||
lambda tp: tp["topic_name"],
|
||||
raw_post["display"]["topic_info"]["topic_details"],
|
||||
)
|
||||
]
|
||||
return [*(tp["topic_name"] for tp in raw_post["display"]["topic_info"]["topic_details"])]
|
||||
|
||||
def _get_info(self, post_type: Category, card) -> tuple[str, list]:
|
||||
if post_type == 1:
|
||||
@ -178,24 +167,16 @@ class Bilibili(NewMessage):
|
||||
url = ""
|
||||
if post_type == 1:
|
||||
# 一般动态
|
||||
url = "https://t.bilibili.com/{}".format(
|
||||
raw_post["desc"]["dynamic_id_str"]
|
||||
)
|
||||
url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"])
|
||||
elif post_type == 2:
|
||||
# 专栏文章
|
||||
url = "https://www.bilibili.com/read/cv{}".format(
|
||||
raw_post["desc"]["rid"]
|
||||
)
|
||||
url = "https://www.bilibili.com/read/cv{}".format(raw_post["desc"]["rid"])
|
||||
elif post_type == 3:
|
||||
# 视频
|
||||
url = "https://www.bilibili.com/video/{}".format(
|
||||
raw_post["desc"]["bvid"]
|
||||
)
|
||||
url = "https://www.bilibili.com/video/{}".format(raw_post["desc"]["bvid"])
|
||||
elif post_type == 4:
|
||||
# 纯文字
|
||||
url = "https://t.bilibili.com/{}".format(
|
||||
raw_post["desc"]["dynamic_id_str"]
|
||||
)
|
||||
url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"])
|
||||
text, pic = self._get_info(post_type, card_content)
|
||||
elif post_type == 5:
|
||||
# 转发
|
||||
@ -261,10 +242,7 @@ class Bilibililive(StatusChange):
|
||||
def get_live_action(self, old_info: Self) -> "Bilibililive.LiveAction":
|
||||
status = Bilibililive.LiveStatus
|
||||
action = Bilibililive.LiveAction
|
||||
if (
|
||||
old_info.live_status in [status.OFF, status.CYCLE]
|
||||
and self.live_status == status.ON
|
||||
):
|
||||
if old_info.live_status in [status.OFF, status.CYCLE] and self.live_status == status.ON:
|
||||
return action.TURN_ON
|
||||
elif old_info.live_status == status.ON and self.live_status in [
|
||||
status.OFF,
|
||||
@ -281,12 +259,8 @@ class Bilibililive(StatusChange):
|
||||
return action.OFF
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
res = await client.get(
|
||||
"https://api.bilibili.com/x/web-interface/card", params={"mid": target}
|
||||
)
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||
res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target})
|
||||
res_data = json.loads(res.text)
|
||||
if res_data["code"]:
|
||||
return None
|
||||
@ -382,9 +356,7 @@ class BilibiliBangumi(StatusChange):
|
||||
_url = "https://api.bilibili.com/pgc/review/user"
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||
res = await client.get(cls._url, params={"media_id": target})
|
||||
res_data = res.json()
|
||||
if res_data["code"]:
|
||||
|
@ -1,15 +1,14 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
from ..post import Post
|
||||
from ..types import RawPost, Target
|
||||
from ..utils import scheduler
|
||||
from .platform import NewMessage
|
||||
from ..types import Target, RawPost
|
||||
|
||||
|
||||
class FF14(NewMessage):
|
||||
|
||||
categories = {}
|
||||
platform_name = "ff14"
|
||||
name = "最终幻想XIV官方公告"
|
||||
@ -21,9 +20,7 @@ class FF14(NewMessage):
|
||||
has_target = False
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||
return "最终幻想XIV官方公告"
|
||||
|
||||
async def get_sub_list(self, _) -> list[RawPost]:
|
||||
|
@ -2,15 +2,15 @@ import re
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
from httpx import AsyncClient
|
||||
from nonebot.log import logger
|
||||
from bs4 import Tag, BeautifulSoup
|
||||
from nonebot.plugin import require
|
||||
|
||||
from ..post import Post
|
||||
from ..types import Category, RawPost, Target
|
||||
from ..types import Target, RawPost, Category
|
||||
from ..utils import SchedulerConfig, http_client
|
||||
from .platform import CategoryNotRecognize, CategoryNotSupport, NewMessage
|
||||
from .platform import NewMessage, CategoryNotSupport, CategoryNotRecognize
|
||||
|
||||
|
||||
class McbbsnewsSchedConf(SchedulerConfig):
|
||||
@ -134,9 +134,9 @@ class McbbsNews(NewMessage):
|
||||
if categoty_name in category_values:
|
||||
category_id = category_keys[category_values.index(categoty_name)]
|
||||
elif categoty_name in known_category_values:
|
||||
raise CategoryNotSupport("McbbsNews订阅暂不支持 {}".format(categoty_name))
|
||||
raise CategoryNotSupport(f"McbbsNews订阅暂不支持 {categoty_name}")
|
||||
else:
|
||||
raise CategoryNotRecognize("Mcbbsnews订阅尚未识别 {}".format(categoty_name))
|
||||
raise CategoryNotRecognize(f"Mcbbsnews订阅尚未识别 {categoty_name}")
|
||||
return category_id
|
||||
|
||||
async def parse(self, post: RawPost) -> Post:
|
||||
@ -170,7 +170,7 @@ class McbbsNews(NewMessage):
|
||||
一般而言每条新闻的长度都很可观,图片生成时间比较喜人
|
||||
"""
|
||||
require("nonebot_plugin_htmlrender")
|
||||
from nonebot_plugin_htmlrender import capture_element, text_to_pic
|
||||
from nonebot_plugin_htmlrender import text_to_pic, capture_element
|
||||
|
||||
try:
|
||||
assert url
|
||||
@ -181,7 +181,7 @@ class McbbsNews(NewMessage):
|
||||
device_scale_factor=3,
|
||||
)
|
||||
assert pic_data
|
||||
except:
|
||||
except Exception:
|
||||
err_info = traceback.format_exc()
|
||||
logger.warning(f"渲染错误:{err_info}")
|
||||
|
||||
|
@ -1,23 +1,21 @@
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
from ..post import Post
|
||||
from ..types import ApiError, RawPost, Target
|
||||
from ..utils import SchedulerConfig
|
||||
from .platform import NewMessage
|
||||
from ..utils import SchedulerConfig
|
||||
from ..types import Target, RawPost, ApiError
|
||||
|
||||
|
||||
class NcmSchedConf(SchedulerConfig):
|
||||
|
||||
name = "music.163.com"
|
||||
schedule_type = "interval"
|
||||
schedule_setting = {"minutes": 1}
|
||||
|
||||
|
||||
class NcmArtist(NewMessage):
|
||||
|
||||
categories = {}
|
||||
platform_name = "ncm-artist"
|
||||
enable_tag = False
|
||||
@ -29,11 +27,9 @@ class NcmArtist(NewMessage):
|
||||
parse_target_promot = "请输入歌手主页(包含数字ID)的链接"
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||
res = await client.get(
|
||||
"https://music.163.com/api/artist/albums/{}".format(target),
|
||||
f"https://music.163.com/api/artist/albums/{target}",
|
||||
headers={"Referer": "https://music.163.com/"},
|
||||
)
|
||||
res_data = res.json()
|
||||
@ -45,16 +41,14 @@ class NcmArtist(NewMessage):
|
||||
async def parse_target(cls, target_text: str) -> Target:
|
||||
if re.match(r"^\d+$", target_text):
|
||||
return Target(target_text)
|
||||
elif match := re.match(
|
||||
r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text
|
||||
):
|
||||
elif match := re.match(r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text):
|
||||
return Target(match.group(1))
|
||||
else:
|
||||
raise cls.ParseTargetException()
|
||||
|
||||
async def get_sub_list(self, target: Target) -> list[RawPost]:
|
||||
res = await self.client.get(
|
||||
"https://music.163.com/api/artist/albums/{}".format(target),
|
||||
f"https://music.163.com/api/artist/albums/{target}",
|
||||
headers={"Referer": "https://music.163.com/"},
|
||||
)
|
||||
res_data = res.json()
|
||||
@ -74,13 +68,10 @@ class NcmArtist(NewMessage):
|
||||
target_name = raw_post["artist"]["name"]
|
||||
pics = [raw_post["picUrl"]]
|
||||
url = "https://music.163.com/#/album?id={}".format(raw_post["id"])
|
||||
return Post(
|
||||
"ncm-artist", text=text, url=url, pics=pics, target_name=target_name
|
||||
)
|
||||
return Post("ncm-artist", text=text, url=url, pics=pics, target_name=target_name)
|
||||
|
||||
|
||||
class NcmRadio(NewMessage):
|
||||
|
||||
categories = {}
|
||||
platform_name = "ncm-radio"
|
||||
enable_tag = False
|
||||
@ -92,9 +83,7 @@ class NcmRadio(NewMessage):
|
||||
parse_target_promot = "请输入主播电台主页(包含数字ID)的链接"
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||
res = await client.post(
|
||||
"http://music.163.com/api/dj/program/byradio",
|
||||
headers={"Referer": "https://music.163.com/"},
|
||||
@ -109,9 +98,7 @@ class NcmRadio(NewMessage):
|
||||
async def parse_target(cls, target_text: str) -> Target:
|
||||
if re.match(r"^\d+$", target_text):
|
||||
return Target(target_text)
|
||||
elif match := re.match(
|
||||
r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text
|
||||
):
|
||||
elif match := re.match(r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text):
|
||||
return Target(match.group(1))
|
||||
else:
|
||||
raise cls.ParseTargetException()
|
||||
|
@ -1,29 +1,32 @@
|
||||
import json
|
||||
import ssl
|
||||
import json
|
||||
import time
|
||||
import typing
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Collection, Optional, Type
|
||||
from collections.abc import Collection
|
||||
|
||||
import httpx
|
||||
from httpx import AsyncClient
|
||||
from nonebot.log import logger
|
||||
from nonebot_plugin_saa import PlatformTarget
|
||||
|
||||
from ..plugin_config import plugin_config
|
||||
from ..post import Post
|
||||
from ..types import Category, RawPost, Tag, Target, UserSubInfo
|
||||
from ..plugin_config import plugin_config
|
||||
from ..utils import ProcessContext, SchedulerConfig
|
||||
from ..types import Tag, Target, RawPost, Category, UserSubInfo
|
||||
|
||||
|
||||
class CategoryNotSupport(Exception):
|
||||
"raise in get_category, when you know the category of the post but don't want to support it or don't support its parsing yet"
|
||||
"""raise in get_category, when you know the category of the post
|
||||
but don't want to support it or don't support its parsing yet
|
||||
"""
|
||||
|
||||
|
||||
class CategoryNotRecognize(Exception):
|
||||
"raise in get_category, when you don't know the category of post"
|
||||
"""raise in get_category, when you don't know the category of post"""
|
||||
|
||||
|
||||
class RegistryMeta(type):
|
||||
@ -42,7 +45,6 @@ class RegistryMeta(type):
|
||||
|
||||
|
||||
class PlatformMeta(RegistryMeta):
|
||||
|
||||
categories: dict[Category, str]
|
||||
store: dict[Target, Any]
|
||||
|
||||
@ -60,8 +62,7 @@ class PlatformABCMeta(PlatformMeta, ABC):
|
||||
|
||||
|
||||
class Platform(metaclass=PlatformABCMeta, base=True):
|
||||
|
||||
scheduler: Type[SchedulerConfig]
|
||||
scheduler: type[SchedulerConfig]
|
||||
ctx: ProcessContext
|
||||
is_common: bool
|
||||
enabled: bool
|
||||
@ -70,16 +71,14 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
||||
categories: dict[Category, str]
|
||||
enable_tag: bool
|
||||
platform_name: str
|
||||
parse_target_promot: Optional[str] = None
|
||||
registry: list[Type["Platform"]]
|
||||
parse_target_promot: str | None = None
|
||||
registry: list[type["Platform"]]
|
||||
client: AsyncClient
|
||||
reverse_category: dict[str, Category]
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
@ -95,11 +94,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
||||
return await self.fetch_new_post(target, users)
|
||||
except httpx.RequestError as err:
|
||||
if plugin_config.bison_show_network_warning:
|
||||
logger.warning(
|
||||
"network connection error: {}, url: {}".format(
|
||||
type(err), err.request.url
|
||||
)
|
||||
)
|
||||
logger.warning(f"network connection error: {type(err)}, url: {err.request.url}")
|
||||
return []
|
||||
except ssl.SSLError as err:
|
||||
if plugin_config.bison_show_network_warning:
|
||||
@ -130,7 +125,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
||||
return Target(target_string)
|
||||
|
||||
@abstractmethod
|
||||
def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]:
|
||||
def get_tags(self, raw_post: RawPost) -> Collection[Tag] | None:
|
||||
"Return Tag list of given RawPost"
|
||||
|
||||
@classmethod
|
||||
@ -201,9 +196,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
||||
) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
res: list[tuple[PlatformTarget, list[Post]]] = []
|
||||
for user, cats, required_tags in users:
|
||||
user_raw_post = await self.filter_user_custom(
|
||||
new_posts, cats, required_tags
|
||||
)
|
||||
user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
|
||||
user_post: list[Post] = []
|
||||
for raw_post in user_raw_post:
|
||||
user_post.append(await self.do_parse(raw_post))
|
||||
@ -211,7 +204,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
||||
return res
|
||||
|
||||
@abstractmethod
|
||||
def get_category(self, post: RawPost) -> Optional[Category]:
|
||||
def get_category(self, post: RawPost) -> Category | None:
|
||||
"Return category of given Rawpost"
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -221,7 +214,7 @@ class MessageProcess(Platform, abstract=True):
|
||||
|
||||
def __init__(self, ctx: ProcessContext, client: AsyncClient):
|
||||
super().__init__(ctx, client)
|
||||
self.parse_cache: dict[Any, Post] = dict()
|
||||
self.parse_cache: dict[Any, Post] = {}
|
||||
|
||||
@abstractmethod
|
||||
def get_id(self, post: RawPost) -> Any:
|
||||
@ -246,7 +239,7 @@ class MessageProcess(Platform, abstract=True):
|
||||
"Get post list of the given target"
|
||||
|
||||
@abstractmethod
|
||||
def get_date(self, post: RawPost) -> Optional[int]:
|
||||
def get_date(self, post: RawPost) -> int | None:
|
||||
"Get post timestamp and return, return None if can't get the time"
|
||||
|
||||
async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
|
||||
@ -286,9 +279,7 @@ class NewMessage(MessageProcess, abstract=True):
|
||||
inited: bool
|
||||
exists_posts: set[Any]
|
||||
|
||||
async def filter_common_with_diff(
|
||||
self, target: Target, raw_post_list: list[RawPost]
|
||||
) -> list[RawPost]:
|
||||
async def filter_common_with_diff(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]:
|
||||
filtered_post = await self.filter_common(raw_post_list)
|
||||
store = self.get_stored_data(target) or self.MessageStorage(False, set())
|
||||
res = []
|
||||
@ -297,11 +288,7 @@ class NewMessage(MessageProcess, abstract=True):
|
||||
for raw_post in filtered_post:
|
||||
post_id = self.get_id(raw_post)
|
||||
store.exists_posts.add(post_id)
|
||||
logger.info(
|
||||
"init {}-{} with {}".format(
|
||||
self.platform_name, target, store.exists_posts
|
||||
)
|
||||
)
|
||||
logger.info(f"init {self.platform_name}-{target} with {store.exists_posts}")
|
||||
store.inited = True
|
||||
else:
|
||||
for raw_post in filtered_post:
|
||||
@ -400,12 +387,11 @@ class SimplePost(MessageProcess, abstract=True):
|
||||
return res
|
||||
|
||||
|
||||
def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
|
||||
|
||||
def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
|
||||
if typing.TYPE_CHECKING:
|
||||
|
||||
class NoTargetGroup(Platform, abstract=True):
|
||||
platform_list: list[Type[Platform]]
|
||||
platform_list: list[type[Platform]]
|
||||
platform_obj_list: list[Platform]
|
||||
|
||||
DUMMY_STR = "_DUMMY"
|
||||
@ -418,24 +404,18 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
|
||||
|
||||
for platform in platform_list:
|
||||
if platform.has_target:
|
||||
raise RuntimeError(
|
||||
"Platform {} should have no target".format(platform.name)
|
||||
)
|
||||
raise RuntimeError(f"Platform {platform.name} should have no target")
|
||||
if name == DUMMY_STR:
|
||||
name = platform.name
|
||||
elif name != platform.name:
|
||||
raise RuntimeError("Platform name for {} not fit".format(platform_name))
|
||||
raise RuntimeError(f"Platform name for {platform_name} not fit")
|
||||
platform_category_key_set = set(platform.categories.keys())
|
||||
if platform_category_key_set & categories_keys:
|
||||
raise RuntimeError(
|
||||
"Platform categories for {} duplicate".format(platform_name)
|
||||
)
|
||||
raise RuntimeError(f"Platform categories for {platform_name} duplicate")
|
||||
categories_keys |= platform_category_key_set
|
||||
categories.update(platform.categories)
|
||||
if platform.scheduler != scheduler:
|
||||
raise RuntimeError(
|
||||
"Platform scheduler for {} not fit".format(platform_name)
|
||||
)
|
||||
raise RuntimeError(f"Platform scheduler for {platform_name} not fit")
|
||||
|
||||
def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient):
|
||||
Platform.__init__(self, ctx, client)
|
||||
@ -444,15 +424,13 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
|
||||
self.platform_obj_list.append(platform_class(ctx, client))
|
||||
|
||||
def __str__(self: "NoTargetGroup") -> str:
|
||||
return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]"
|
||||
return "[" + " ".join(x.name for x in self.platform_list) + "]"
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target):
|
||||
return await platform_list[0].get_target_name(client, target)
|
||||
|
||||
async def fetch_new_post(
|
||||
self: "NoTargetGroup", target: Target, users: list[UserSubInfo]
|
||||
):
|
||||
async def fetch_new_post(self: "NoTargetGroup", target: Target, users: list[UserSubInfo]):
|
||||
res = defaultdict(list)
|
||||
for platform in self.platform_obj_list:
|
||||
platform_res = await platform.fetch_new_post(target=target, users=users)
|
||||
|
@ -1,10 +1,10 @@
|
||||
import calendar
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import feedparser
|
||||
from bs4 import BeautifulSoup as bs
|
||||
from httpx import AsyncClient
|
||||
from bs4 import BeautifulSoup as bs
|
||||
|
||||
from ..post import Post
|
||||
from ..types import RawPost, Target
|
||||
@ -20,7 +20,6 @@ class RssSchedConf(SchedulerConfig):
|
||||
|
||||
|
||||
class Rss(NewMessage):
|
||||
|
||||
categories = {}
|
||||
enable_tag = False
|
||||
platform_name = "rss"
|
||||
@ -31,9 +30,7 @@ class Rss(NewMessage):
|
||||
has_target = True
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||
res = await client.get(target, timeout=10.0)
|
||||
feed = feedparser.parse(res.text)
|
||||
return feed["feed"]["title"]
|
||||
@ -69,7 +66,7 @@ class Rss(NewMessage):
|
||||
else:
|
||||
text = f"{title}\n\n{desc}"
|
||||
|
||||
pics = list(map(lambda x: x.attrs["src"], soup("img")))
|
||||
pics = [x.attrs["src"] for x in soup("img")]
|
||||
if raw_post.get("media_content"):
|
||||
for media in raw_post["media_content"]:
|
||||
if media.get("medium") == "image" and media.get("url"):
|
||||
|
@ -1,17 +1,16 @@
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
import json
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from bs4 import BeautifulSoup as bs
|
||||
from httpx import AsyncClient
|
||||
from nonebot.log import logger
|
||||
from bs4 import BeautifulSoup as bs
|
||||
|
||||
from ..post import Post
|
||||
from ..types import *
|
||||
from ..utils import SchedulerConfig, http_client
|
||||
from .platform import NewMessage
|
||||
from ..utils import SchedulerConfig, http_client
|
||||
from ..types import Tag, Target, RawPost, ApiError, Category
|
||||
|
||||
|
||||
class WeiboSchedConf(SchedulerConfig):
|
||||
@ -21,7 +20,6 @@ class WeiboSchedConf(SchedulerConfig):
|
||||
|
||||
|
||||
class Weibo(NewMessage):
|
||||
|
||||
categories = {
|
||||
1: "转发",
|
||||
2: "视频",
|
||||
@ -38,13 +36,9 @@ class Weibo(NewMessage):
|
||||
parse_target_promot = "请输入用户主页(包含数字UID)的链接"
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||
param = {"containerid": "100505" + target}
|
||||
res = await client.get(
|
||||
"https://m.weibo.cn/api/container/getIndex", params=param
|
||||
)
|
||||
res = await client.get("https://m.weibo.cn/api/container/getIndex", params=param)
|
||||
res_dict = json.loads(res.text)
|
||||
if res_dict.get("ok") == 1:
|
||||
return res_dict["data"]["userInfo"]["screen_name"]
|
||||
@ -63,13 +57,14 @@ class Weibo(NewMessage):
|
||||
|
||||
async def get_sub_list(self, target: Target) -> list[RawPost]:
|
||||
params = {"containerid": "107603" + target}
|
||||
res = await self.client.get(
|
||||
"https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0
|
||||
)
|
||||
res = await self.client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0)
|
||||
res_data = json.loads(res.text)
|
||||
if not res_data["ok"] and res_data["msg"] != "这里还没有内容":
|
||||
raise ApiError(res.request.url)
|
||||
custom_filter: Callable[[RawPost], bool] = lambda d: d["card_type"] == 9
|
||||
|
||||
def custom_filter(d: RawPost) -> bool:
|
||||
return d["card_type"] == 9
|
||||
|
||||
return list(filter(custom_filter, res_data["data"]["cards"]))
|
||||
|
||||
def get_id(self, post: RawPost) -> Any:
|
||||
@ -79,44 +74,32 @@ class Weibo(NewMessage):
|
||||
return raw_post["card_type"] == 9
|
||||
|
||||
def get_date(self, raw_post: RawPost) -> float:
|
||||
created_time = datetime.strptime(
|
||||
raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y"
|
||||
)
|
||||
created_time = datetime.strptime(raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y")
|
||||
return created_time.timestamp()
|
||||
|
||||
def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]:
|
||||
def get_tags(self, raw_post: RawPost) -> list[Tag] | None:
|
||||
"Return Tag list of given RawPost"
|
||||
text = raw_post["mblog"]["text"]
|
||||
soup = bs(text, "html.parser")
|
||||
res = list(
|
||||
map(
|
||||
lambda x: x[1:-1],
|
||||
filter(
|
||||
lambda s: s[0] == "#" and s[-1] == "#",
|
||||
map(lambda x: x.text, soup.find_all("span", class_="surl-text")),
|
||||
),
|
||||
res = [
|
||||
x[1:-1]
|
||||
for x in filter(
|
||||
lambda s: s[0] == "#" and s[-1] == "#",
|
||||
(x.text for x in soup.find_all("span", class_="surl-text")),
|
||||
)
|
||||
)
|
||||
super_topic_img = soup.find(
|
||||
"img", src=re.compile(r"timeline_card_small_super_default")
|
||||
)
|
||||
]
|
||||
super_topic_img = soup.find("img", src=re.compile(r"timeline_card_small_super_default"))
|
||||
if super_topic_img:
|
||||
try:
|
||||
res.append(
|
||||
super_topic_img.parent.parent.find("span", class_="surl-text").text # type: ignore
|
||||
+ "超话"
|
||||
)
|
||||
except:
|
||||
logger.info("super_topic extract error: {}".format(text))
|
||||
res.append(super_topic_img.parent.parent.find("span", class_="surl-text").text + "超话") # type: ignore
|
||||
except Exception:
|
||||
logger.info(f"super_topic extract error: {text}")
|
||||
return res
|
||||
|
||||
def get_category(self, raw_post: RawPost) -> Category:
|
||||
if raw_post["mblog"].get("retweeted_status"):
|
||||
return Category(1)
|
||||
elif (
|
||||
raw_post["mblog"].get("page_info")
|
||||
and raw_post["mblog"]["page_info"].get("type") == "video"
|
||||
):
|
||||
elif raw_post["mblog"].get("page_info") and raw_post["mblog"]["page_info"].get("type") == "video":
|
||||
return Category(2)
|
||||
elif raw_post["mblog"].get("pics"):
|
||||
return Category(3)
|
||||
@ -129,7 +112,8 @@ class Weibo(NewMessage):
|
||||
|
||||
async def parse(self, raw_post: RawPost) -> Post:
|
||||
header = {
|
||||
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
|
||||
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,"
|
||||
"*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
|
||||
"accept-language": "zh-CN,zh;q=0.9",
|
||||
"authority": "m.weibo.cn",
|
||||
"cache-control": "max-age=0",
|
||||
@ -147,26 +131,16 @@ class Weibo(NewMessage):
|
||||
retweeted = True
|
||||
pic_num = info["retweeted_status"]["pic_num"] if retweeted else info["pic_num"]
|
||||
if info["isLongText"] or pic_num > 9:
|
||||
res = await self.client.get(
|
||||
"https://m.weibo.cn/detail/{}".format(info["mid"]), headers=header
|
||||
)
|
||||
res = await self.client.get(f"https://m.weibo.cn/detail/{info['mid']}", headers=header)
|
||||
try:
|
||||
match = re.search(r'"status": ([\s\S]+),\s+"call"', res.text)
|
||||
assert match
|
||||
full_json_text = match.group(1)
|
||||
info = json.loads(full_json_text)
|
||||
except:
|
||||
logger.info(
|
||||
"detail message error: https://m.weibo.cn/detail/{}".format(
|
||||
info["mid"]
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.info(f"detail message error: https://m.weibo.cn/detail/{info['mid']}")
|
||||
parsed_text = self._get_text(info["text"])
|
||||
raw_pics_list = (
|
||||
info["retweeted_status"].get("pics", [])
|
||||
if retweeted
|
||||
else info.get("pics", [])
|
||||
)
|
||||
raw_pics_list = info["retweeted_status"].get("pics", []) if retweeted else info.get("pics", [])
|
||||
pic_urls = [img["large"]["url"] for img in raw_pics_list]
|
||||
pics = []
|
||||
for pic_url in pic_urls:
|
||||
@ -174,7 +148,7 @@ class Weibo(NewMessage):
|
||||
res = await client.get(pic_url)
|
||||
res.raise_for_status()
|
||||
pics.append(res.content)
|
||||
detail_url = "https://weibo.com/{}/{}".format(info["user"]["id"], info["bid"])
|
||||
detail_url = f"https://weibo.com/{info['user']['id']}/{info['bid']}"
|
||||
# return parsed_text, detail_url, pic_urls
|
||||
return Post(
|
||||
"weibo",
|
||||
|
@ -1,11 +1,8 @@
|
||||
from typing import Optional
|
||||
|
||||
import nonebot
|
||||
from pydantic import BaseSettings
|
||||
|
||||
|
||||
class PlugConfig(BaseSettings):
|
||||
|
||||
bison_config_path: str = ""
|
||||
bison_use_pic: bool = False
|
||||
bison_init_filter: bool = True
|
||||
@ -17,8 +14,12 @@ class PlugConfig(BaseSettings):
|
||||
bison_use_pic_merge: int = 0 # 多图片时启用图片合并转发(仅限群)
|
||||
# 0:不启用;1:首条消息单独发送,剩余照片合并转发;2以及以上:所有消息全部合并转发
|
||||
bison_resend_times: int = 0
|
||||
bison_proxy: Optional[str]
|
||||
bison_ua: str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36"
|
||||
bison_proxy: str | None
|
||||
bison_ua: str = (
|
||||
"Mozilla/5.0 (X11; Linux x86_64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/51.0.2704.103 Safari/537.36"
|
||||
)
|
||||
bison_show_network_warning: bool = True
|
||||
|
||||
class Config:
|
||||
|
@ -1,3 +1,3 @@
|
||||
from .post import Post
|
||||
|
||||
__all__ = ["Post", "CustomPost"]
|
||||
__all__ = ["Post"]
|
||||
|
@ -1,7 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from functools import reduce
|
||||
from typing import Optional
|
||||
from abc import abstractmethod
|
||||
from dataclasses import field, dataclass
|
||||
|
||||
from nonebot_plugin_saa import MessageFactory, MessageSegmentFactory
|
||||
|
||||
@ -25,12 +24,12 @@ class BasePost:
|
||||
class OptionalMixin:
|
||||
# Because of https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
|
||||
|
||||
override_use_pic: Optional[bool] = None
|
||||
override_use_pic: bool | None = None
|
||||
compress: bool = False
|
||||
extra_msg: list[MessageFactory] = field(default_factory=list)
|
||||
|
||||
def _use_pic(self):
|
||||
if not self.override_use_pic is None:
|
||||
if self.override_use_pic is not None:
|
||||
return self.override_use_pic
|
||||
return plugin_config.bison_use_pic
|
||||
|
||||
@ -44,13 +43,9 @@ class AbstractPost(OptionalMixin, BasePost):
|
||||
msg_segments = await self.generate_text_messages()
|
||||
if msg_segments:
|
||||
if self.compress:
|
||||
msgs = [
|
||||
reduce(lambda x, y: x.append(y), msg_segments, MessageFactory([]))
|
||||
]
|
||||
msgs = [reduce(lambda x, y: x.append(y), msg_segments, MessageFactory([]))]
|
||||
else:
|
||||
msgs = list(
|
||||
map(lambda msg_segment: MessageFactory([msg_segment]), msg_segments)
|
||||
)
|
||||
msgs = [MessageFactory([msg_segment]) for msg_segment in msg_segments]
|
||||
else:
|
||||
msgs = []
|
||||
msgs.extend(self.extra_msg)
|
||||
|
@ -1,19 +1,17 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from dataclasses import field, dataclass
|
||||
|
||||
from nonebot.adapters.onebot.v11 import MessageSegment
|
||||
from nonebot.log import logger
|
||||
from nonebot.plugin import require
|
||||
from nonebot_plugin_saa import Image, MessageFactory, MessageSegmentFactory, Text
|
||||
from nonebot.adapters.onebot.v11 import MessageSegment
|
||||
from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
|
||||
|
||||
from .abstract_post import AbstractPost, BasePost
|
||||
from .abstract_post import BasePost, AbstractPost
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CustomPost(BasePost):
|
||||
|
||||
ms_factories: list[MessageSegmentFactory] = field(default_factory=list)
|
||||
css_path: Optional[str] = None # 模板文件所用css路径
|
||||
css_path: str | None = None # 模板文件所用css路径
|
||||
|
||||
async def generate_text_messages(self) -> list[MessageSegmentFactory]:
|
||||
return self.ms_factories
|
||||
@ -31,15 +29,13 @@ class _CustomPost(BasePost):
|
||||
for message_segment in self.ms_factories:
|
||||
match message_segment:
|
||||
case Text(data={"text": text}):
|
||||
md += "{}<br>".format(text)
|
||||
md += f"{text}<br>"
|
||||
case Image(data={"image": image}):
|
||||
# use onebot v11 to convert image into url
|
||||
ob11_image = MessageSegment.image(image)
|
||||
md += "\n".format(ob11_image.data["file"])
|
||||
case _:
|
||||
logger.warning(
|
||||
"custom_post不支持处理类型:{}".format(type(message_segment))
|
||||
)
|
||||
logger.warning(f"custom_post不支持处理类型:{type(message_segment)}")
|
||||
continue
|
||||
|
||||
return md
|
||||
|
@ -1,30 +1,27 @@
|
||||
from dataclasses import dataclass, field
|
||||
from functools import reduce
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union
|
||||
from dataclasses import field, dataclass
|
||||
|
||||
import nonebot_plugin_saa as saa
|
||||
from nonebot.log import logger
|
||||
from nonebot_plugin_saa.utils import MessageFactory, MessageSegmentFactory
|
||||
from PIL import Image
|
||||
from nonebot.log import logger
|
||||
import nonebot_plugin_saa as saa
|
||||
from nonebot_plugin_saa.utils import MessageSegmentFactory
|
||||
|
||||
from ..utils import http_client, parse_text
|
||||
from .abstract_post import AbstractPost, BasePost, OptionalMixin
|
||||
from ..utils import parse_text, http_client
|
||||
from .abstract_post import BasePost, AbstractPost
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Post(BasePost):
|
||||
|
||||
target_type: str
|
||||
text: str
|
||||
url: Optional[str] = None
|
||||
target_name: Optional[str] = None
|
||||
pics: list[Union[str, bytes]] = field(default_factory=list)
|
||||
url: str | None = None
|
||||
target_name: str | None = None
|
||||
pics: list[str | bytes] = field(default_factory=list)
|
||||
|
||||
_message: Optional[list[MessageSegmentFactory]] = None
|
||||
_pic_message: Optional[list[MessageSegmentFactory]] = None
|
||||
_message: list[MessageSegmentFactory] | None = None
|
||||
_pic_message: list[MessageSegmentFactory] | None = None
|
||||
|
||||
async def _pic_url_to_image(self, data: Union[str, bytes]) -> Image.Image:
|
||||
async def _pic_url_to_image(self, data: str | bytes) -> Image.Image:
|
||||
pic_buffer = BytesIO()
|
||||
if isinstance(data, str):
|
||||
async with http_client() as client:
|
||||
@ -101,22 +98,19 @@ class _Post(BasePost):
|
||||
self.pics.insert(0, target_io.getvalue())
|
||||
|
||||
async def generate_text_messages(self) -> list[MessageSegmentFactory]:
|
||||
|
||||
if self._message is None:
|
||||
await self._pic_merge()
|
||||
msg_segments: list[MessageSegmentFactory] = []
|
||||
text = ""
|
||||
if self.text:
|
||||
text += "{}".format(
|
||||
self.text if len(self.text) < 500 else self.text[:500] + "..."
|
||||
)
|
||||
text += "{}".format(self.text if len(self.text) < 500 else self.text[:500] + "...")
|
||||
if text:
|
||||
text += "\n"
|
||||
text += "来源: {}".format(self.target_type)
|
||||
text += f"来源: {self.target_type}"
|
||||
if self.target_name:
|
||||
text += " {}".format(self.target_name)
|
||||
text += f" {self.target_name}"
|
||||
if self.url:
|
||||
text += " \n详情: {}".format(self.url)
|
||||
text += f" \n详情: {self.url}"
|
||||
msg_segments.append(saa.Text(text))
|
||||
for pic in self.pics:
|
||||
msg_segments.append(saa.Image(pic))
|
||||
@ -124,17 +118,16 @@ class _Post(BasePost):
|
||||
return self._message
|
||||
|
||||
async def generate_pic_messages(self) -> list[MessageSegmentFactory]:
|
||||
|
||||
if self._pic_message is None:
|
||||
await self._pic_merge()
|
||||
msg_segments: list[MessageSegmentFactory] = []
|
||||
text = ""
|
||||
if self.text:
|
||||
text += "{}".format(self.text)
|
||||
text += f"{self.text}"
|
||||
text += "\n"
|
||||
text += "来源: {}".format(self.target_type)
|
||||
text += f"来源: {self.target_type}"
|
||||
if self.target_name:
|
||||
text += " {}".format(self.target_name)
|
||||
text += f" {self.target_name}"
|
||||
msg_segments.append(await parse_text(text))
|
||||
if not self.target_type == "rss" and self.url:
|
||||
msg_segments.append(saa.Text(self.url))
|
||||
@ -149,14 +142,7 @@ class _Post(BasePost):
|
||||
self.target_name,
|
||||
self.text if len(self.text) < 500 else self.text[:500] + "...",
|
||||
self.url,
|
||||
", ".join(
|
||||
map(
|
||||
lambda x: "b64img"
|
||||
if isinstance(x, bytes) or x.startswith("base64")
|
||||
else x,
|
||||
self.pics,
|
||||
)
|
||||
),
|
||||
", ".join("b64img" if isinstance(x, bytes) or x.startswith("base64") else x for x in self.pics),
|
||||
)
|
||||
|
||||
|
||||
|
@ -1 +1,3 @@
|
||||
from .manager import *
|
||||
from .manager import init_scheduler, scheduler_dict, handle_delete_target, handle_insert_new_target
|
||||
|
||||
__all__ = ["init_scheduler", "handle_delete_target", "handle_insert_new_target", "scheduler_dict"]
|
||||
|
@ -1,18 +1,16 @@
|
||||
from typing import Type
|
||||
|
||||
from ..config import config
|
||||
from ..config.db_model import Target
|
||||
from ..platform import platform_manager
|
||||
from ..types import Target as T_Target
|
||||
from ..utils import SchedulerConfig
|
||||
from .scheduler import Scheduler
|
||||
from ..utils import SchedulerConfig
|
||||
from ..config.db_model import Target
|
||||
from ..types import Target as T_Target
|
||||
from ..platform import platform_manager
|
||||
|
||||
scheduler_dict: dict[Type[SchedulerConfig], Scheduler] = {}
|
||||
scheduler_dict: dict[type[SchedulerConfig], Scheduler] = {}
|
||||
|
||||
|
||||
async def init_scheduler():
|
||||
_schedule_class_dict: dict[Type[SchedulerConfig], list[Target]] = {}
|
||||
_schedule_class_platform_dict: dict[Type[SchedulerConfig], list[str]] = {}
|
||||
_schedule_class_dict: dict[type[SchedulerConfig], list[Target]] = {}
|
||||
_schedule_class_platform_dict: dict[type[SchedulerConfig], list[str]] = {}
|
||||
for platform in platform_manager.values():
|
||||
scheduler_config = platform.scheduler
|
||||
if not hasattr(scheduler_config, "name") or not scheduler_config.name:
|
||||
@ -33,9 +31,7 @@ async def init_scheduler():
|
||||
for target in target_list:
|
||||
schedulable_args.append((target.platform_name, T_Target(target.target)))
|
||||
platform_name_list = _schedule_class_platform_dict[scheduler_config]
|
||||
scheduler_dict[scheduler_config] = Scheduler(
|
||||
scheduler_config, schedulable_args, platform_name_list
|
||||
)
|
||||
scheduler_dict[scheduler_config] = Scheduler(scheduler_config, schedulable_args, platform_name_list)
|
||||
config.register_add_target_hook(handle_insert_new_target)
|
||||
config.register_delete_target_hook(handle_delete_target)
|
||||
|
||||
|
@ -1,14 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Type
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
from nonebot_plugin_saa.utils.exceptions import NoBotFound
|
||||
|
||||
from ..config import config
|
||||
from ..platform import platform_manager
|
||||
from ..send import send_msgs
|
||||
from ..types import Target
|
||||
from ..config import config
|
||||
from ..send import send_msgs
|
||||
from ..platform import platform_manager
|
||||
from ..utils import ProcessContext, SchedulerConfig
|
||||
|
||||
|
||||
@ -20,12 +19,11 @@ class Schedulable:
|
||||
|
||||
|
||||
class Scheduler:
|
||||
|
||||
schedulable_list: list[Schedulable]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler_config: Type[SchedulerConfig],
|
||||
scheduler_config: type[SchedulerConfig],
|
||||
schedulables: list[tuple[str, Target]],
|
||||
platform_name_list: list[str],
|
||||
):
|
||||
@ -37,15 +35,12 @@ class Scheduler:
|
||||
self.scheduler_config_obj = self.scheduler_config()
|
||||
self.schedulable_list = []
|
||||
for platform_name, target in schedulables:
|
||||
self.schedulable_list.append(
|
||||
Schedulable(
|
||||
platform_name=platform_name, target=target, current_weight=0
|
||||
)
|
||||
)
|
||||
self.schedulable_list.append(Schedulable(platform_name=platform_name, target=target, current_weight=0))
|
||||
self.platform_name_list = platform_name_list
|
||||
self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值
|
||||
logger.info(
|
||||
f"register scheduler for {self.name} with {self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}"
|
||||
f"register scheduler for {self.name} with "
|
||||
f"{self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}"
|
||||
)
|
||||
scheduler.add_job(
|
||||
self.exec_fetch,
|
||||
@ -53,7 +48,7 @@ class Scheduler:
|
||||
**self.scheduler_config.schedule_setting,
|
||||
)
|
||||
|
||||
async def get_next_schedulable(self) -> Optional[Schedulable]:
|
||||
async def get_next_schedulable(self) -> Schedulable | None:
|
||||
if not self.schedulable_list:
|
||||
return None
|
||||
cur_weight = await config.get_current_weight_val(self.platform_name_list)
|
||||
@ -61,16 +56,9 @@ class Scheduler:
|
||||
self.pre_weight_val = 0
|
||||
cur_max_schedulable = None
|
||||
for schedulable in self.schedulable_list:
|
||||
schedulable.current_weight += cur_weight[
|
||||
f"{schedulable.platform_name}-{schedulable.target}"
|
||||
]
|
||||
weight_sum += cur_weight[
|
||||
f"{schedulable.platform_name}-{schedulable.target}"
|
||||
]
|
||||
if (
|
||||
not cur_max_schedulable
|
||||
or cur_max_schedulable.current_weight < schedulable.current_weight
|
||||
):
|
||||
schedulable.current_weight += cur_weight[f"{schedulable.platform_name}-{schedulable.target}"]
|
||||
weight_sum += cur_weight[f"{schedulable.platform_name}-{schedulable.target}"]
|
||||
if not cur_max_schedulable or cur_max_schedulable.current_weight < schedulable.current_weight:
|
||||
cur_max_schedulable = schedulable
|
||||
assert cur_max_schedulable
|
||||
cur_max_schedulable.current_weight -= weight_sum
|
||||
@ -80,9 +68,7 @@ class Scheduler:
|
||||
context = ProcessContext()
|
||||
if not (schedulable := await self.get_next_schedulable()):
|
||||
return
|
||||
logger.trace(
|
||||
f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}"
|
||||
)
|
||||
logger.trace(f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}")
|
||||
send_userinfo_list = await config.get_platform_target_subscribers(
|
||||
schedulable.platform_name, schedulable.target
|
||||
)
|
||||
@ -92,9 +78,7 @@ class Scheduler:
|
||||
|
||||
try:
|
||||
platform_obj = platform_manager[schedulable.platform_name](context, client)
|
||||
to_send = await platform_obj.do_fetch_new_post(
|
||||
schedulable.target, send_userinfo_list
|
||||
)
|
||||
to_send = await platform_obj.do_fetch_new_post(schedulable.target, send_userinfo_list)
|
||||
except Exception as err:
|
||||
records = context.gen_req_records()
|
||||
for record in records:
|
||||
@ -107,7 +91,7 @@ class Scheduler:
|
||||
|
||||
for user, send_list in to_send:
|
||||
for send_post in send_list:
|
||||
logger.info("send to {}: {}".format(user, send_post))
|
||||
logger.info(f"send to {user}: {send_post}")
|
||||
try:
|
||||
await send_msgs(
|
||||
user,
|
||||
@ -119,19 +103,14 @@ class Scheduler:
|
||||
def insert_new_schedulable(self, platform_name: str, target: Target):
|
||||
self.pre_weight_val += 1000
|
||||
self.schedulable_list.append(Schedulable(platform_name, target, 1000))
|
||||
logger.info(
|
||||
f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})"
|
||||
)
|
||||
logger.info(f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})")
|
||||
|
||||
def delete_schedulable(self, platform_name, target: Target):
|
||||
if not self.schedulable_list:
|
||||
return
|
||||
to_find_idx = None
|
||||
for idx, schedulable in enumerate(self.schedulable_list):
|
||||
if (
|
||||
schedulable.platform_name == platform_name
|
||||
and schedulable.target == target
|
||||
):
|
||||
if schedulable.platform_name == platform_name and schedulable.target == target:
|
||||
to_find_idx = idx
|
||||
break
|
||||
if to_find_idx is not None:
|
||||
|
@ -1,21 +1,23 @@
|
||||
import importlib
|
||||
import json
|
||||
import time
|
||||
from functools import partial, wraps
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Coroutine, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
from functools import wraps, partial
|
||||
from collections.abc import Callable, Coroutine
|
||||
|
||||
from nonebot.log import logger
|
||||
|
||||
from ..config.subs_io import subscribes_export, subscribes_import
|
||||
from ..config.subs_io.nbesf_model import v1, v2
|
||||
from ..scheduler.manager import init_scheduler
|
||||
from ..config.subs_io.nbesf_model import v1, v2
|
||||
from ..config.subs_io import subscribes_export, subscribes_import
|
||||
|
||||
try:
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import anyio
|
||||
import click
|
||||
from typing_extensions import ParamSpec
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError("请使用 `pip install nonebot-bison[cli]` 安装所需依赖") from e
|
||||
|
||||
@ -65,9 +67,7 @@ def path_init(ctx, param, value):
|
||||
|
||||
|
||||
@cli.command(help="导出Nonebot Bison Exchangable Subcribes File", name="export")
|
||||
@click.option(
|
||||
"--path", "-p", default=None, callback=path_init, help="导出路径, 如果不指定,则默认为工作目录"
|
||||
)
|
||||
@click.option("--path", "-p", default=None, callback=path_init, help="导出路径, 如果不指定,则默认为工作目录")
|
||||
@click.option(
|
||||
"--format",
|
||||
default="json",
|
||||
@ -76,7 +76,6 @@ def path_init(ctx, param, value):
|
||||
)
|
||||
@run_async
|
||||
async def subs_export(path: Path, format: str):
|
||||
|
||||
await init_scheduler()
|
||||
|
||||
export_file = path / f"bison_subscribes_export_{int(time.time())}.{format}"
|
||||
@ -121,7 +120,6 @@ async def subs_export(path: Path, format: str):
|
||||
)
|
||||
@run_async
|
||||
async def subs_import(path: str, format: str):
|
||||
|
||||
await init_scheduler()
|
||||
|
||||
import_file_path = Path(path)
|
||||
|
@ -1,17 +1,16 @@
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from typing import Deque
|
||||
|
||||
from nonebot.adapters.onebot.v11.exception import ActionFailed
|
||||
from nonebot.log import logger
|
||||
from nonebot_plugin_saa import AggregatedMessageFactory, MessageFactory, PlatformTarget
|
||||
from nonebot.adapters.onebot.v11.exception import ActionFailed
|
||||
from nonebot_plugin_saa.utils.auto_select_bot import refresh_bots
|
||||
from nonebot_plugin_saa import MessageFactory, PlatformTarget, AggregatedMessageFactory
|
||||
|
||||
from .plugin_config import plugin_config
|
||||
|
||||
Sendable = MessageFactory | AggregatedMessageFactory
|
||||
|
||||
QUEUE: Deque[tuple[PlatformTarget, Sendable, int]] = deque()
|
||||
QUEUE: deque[tuple[PlatformTarget, Sendable, int]] = deque()
|
||||
|
||||
MESSGE_SEND_INTERVAL = 1.5
|
||||
|
||||
|
@ -1,21 +1,20 @@
|
||||
import contextlib
|
||||
from typing import Type, cast
|
||||
|
||||
from nonebot.adapters import Message, MessageTemplate
|
||||
from nonebot.typing import T_State
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.params import Arg, ArgPlainText
|
||||
from nonebot.typing import T_State
|
||||
from nonebot_plugin_saa import PlatformTarget, SupportedAdapters, Text
|
||||
from nonebot.adapters import Message, MessageTemplate
|
||||
from nonebot_plugin_saa import Text, PlatformTarget, SupportedAdapters
|
||||
|
||||
from ..apis import check_sub_target
|
||||
from ..config import config
|
||||
from ..config.db_config import SubscribeDupException
|
||||
from ..platform import Platform, platform_manager
|
||||
from ..types import Target
|
||||
from ..config import config
|
||||
from ..apis import check_sub_target
|
||||
from ..platform import Platform, platform_manager
|
||||
from ..config.db_config import SubscribeDupException
|
||||
from .utils import common_platform, ensure_user_info, gen_handle_cancel
|
||||
|
||||
|
||||
def do_add_sub(add_sub: Type[Matcher]):
|
||||
def do_add_sub(add_sub: type[Matcher]):
|
||||
handle_cancel = gen_handle_cancel(add_sub, "已中止订阅")
|
||||
|
||||
add_sub.handle()(ensure_user_info(add_sub))
|
||||
@ -25,12 +24,7 @@ def do_add_sub(add_sub: Type[Matcher]):
|
||||
state["_prompt"] = (
|
||||
"请输入想要订阅的平台,目前支持,请输入冒号左边的名称:\n"
|
||||
+ "".join(
|
||||
[
|
||||
"{}:{}\n".format(
|
||||
platform_name, platform_manager[platform_name].name
|
||||
)
|
||||
for platform_name in common_platform
|
||||
]
|
||||
[f"{platform_name}: {platform_manager[platform_name].name}\n" for platform_name in common_platform]
|
||||
)
|
||||
+ "要查看全部平台请输入:“全部”\n中止订阅过程请输入:“取消”"
|
||||
)
|
||||
@ -39,10 +33,7 @@ def do_add_sub(add_sub: Type[Matcher]):
|
||||
async def parse_platform(state: T_State, platform: str = ArgPlainText()) -> None:
|
||||
if platform == "全部":
|
||||
message = "全部平台\n" + "\n".join(
|
||||
[
|
||||
"{}:{}".format(platform_name, platform.name)
|
||||
for platform_name, platform in platform_manager.items()
|
||||
]
|
||||
[f"{platform_name}: {platform.name}" for platform_name, platform in platform_manager.items()]
|
||||
)
|
||||
await add_sub.reject(message)
|
||||
elif platform == "取消":
|
||||
@ -57,9 +48,7 @@ def do_add_sub(add_sub: Type[Matcher]):
|
||||
cur_platform = platform_manager[state["platform"]]
|
||||
if cur_platform.has_target:
|
||||
state["_prompt"] = (
|
||||
("1." + cur_platform.parse_target_promot + "\n2.")
|
||||
if cur_platform.parse_target_promot
|
||||
else ""
|
||||
("1." + cur_platform.parse_target_promot + "\n2.") if cur_platform.parse_target_promot else ""
|
||||
) + "请输入订阅用户的id\n查询id获取方法请回复:“查询”"
|
||||
else:
|
||||
matcher.set_arg("raw_id", None) # type: ignore
|
||||
@ -81,9 +70,7 @@ def do_add_sub(add_sub: Type[Matcher]):
|
||||
image = "https://s3.bmp.ovh/imgs/2022/03/ab3cc45d83bd3dd3.jpg"
|
||||
msg.overwrite(
|
||||
SupportedAdapters.onebot_v11,
|
||||
MessageSegment.share(
|
||||
url=url, title=title, content=content, image=image
|
||||
),
|
||||
MessageSegment.share(url=url, title=title, content=content, image=image),
|
||||
)
|
||||
await msg.reject()
|
||||
platform = platform_manager[state["platform"]]
|
||||
@ -99,14 +86,12 @@ def do_add_sub(add_sub: Type[Matcher]):
|
||||
await add_sub.reject("id输入错误")
|
||||
state["id"] = raw_id_text
|
||||
state["name"] = name
|
||||
except (Platform.ParseTargetException):
|
||||
except Platform.ParseTargetException:
|
||||
await add_sub.reject("不能从你的输入中提取出id,请检查你输入的内容是否符合预期")
|
||||
else:
|
||||
await add_sub.send(
|
||||
"即将订阅的用户为:{} {} {}\n如有错误请输入“取消”重新订阅".format(
|
||||
state["platform"], state["name"], state["id"]
|
||||
)
|
||||
)
|
||||
f"即将订阅的用户为:{state['platform']} {state['name']} {state['id']}\n如有错误请输入“取消”重新订阅"
|
||||
) # noqa: E501
|
||||
|
||||
@add_sub.handle()
|
||||
async def prepare_get_categories(matcher: Matcher, state: T_State):
|
||||
@ -125,7 +110,7 @@ def do_add_sub(add_sub: Type[Matcher]):
|
||||
if platform_manager[state["platform"]].categories:
|
||||
for cat in raw_cats_text.split():
|
||||
if cat not in platform_manager[state["platform"]].reverse_category:
|
||||
await add_sub.reject("不支持 {}".format(cat))
|
||||
await add_sub.reject(f"不支持 {cat}")
|
||||
res.append(platform_manager[state["platform"]].reverse_category[cat])
|
||||
state["cats"] = res
|
||||
|
||||
@ -135,14 +120,16 @@ def do_add_sub(add_sub: Type[Matcher]):
|
||||
matcher.set_arg("raw_tags", None) # type: ignore
|
||||
state["tags"] = []
|
||||
return
|
||||
state["_prompt"] = '请输入要订阅/屏蔽的标签(不含#号)\n多个标签请使用空格隔开\n订阅所有标签输入"全部标签"\n具体规则回复"详情"'
|
||||
state["_prompt"] = "请输入要订阅/屏蔽的标签(不含#号)\n" "多个标签请使用空格隔开\n" '订阅所有标签输入"全部标签"\n' '具体规则回复"详情"' # noqa: E501
|
||||
|
||||
@add_sub.got("raw_tags", MessageTemplate("{_prompt}"), [handle_cancel])
|
||||
async def parser_tags(state: T_State, raw_tags: Message = Arg()):
|
||||
raw_tags_text = raw_tags.extract_plain_text()
|
||||
if raw_tags_text == "详情":
|
||||
await add_sub.reject(
|
||||
"订阅标签直接输入标签内容\n屏蔽标签请在标签名称前添加~号\n详见https://nonebot-bison.netlify.app/usage/#%E5%B9%B3%E5%8F%B0%E8%AE%A2%E9%98%85%E6%A0%87%E7%AD%BE-tag"
|
||||
"订阅标签直接输入标签内容\n"
|
||||
"屏蔽标签请在标签名称前添加~号\n"
|
||||
"详见https://nonebot-bison.netlify.app/usage/#%E5%B9%B3%E5%8F%B0%E8%AE%A2%E9%98%85%E6%A0%87%E7%AD%BE-tag"
|
||||
)
|
||||
if raw_tags_text in ["全部标签", "全部", "全标签"]:
|
||||
state["tags"] = []
|
||||
@ -150,9 +137,7 @@ def do_add_sub(add_sub: Type[Matcher]):
|
||||
state["tags"] = raw_tags_text.split()
|
||||
|
||||
@add_sub.handle()
|
||||
async def add_sub_process(
|
||||
state: T_State, user: PlatformTarget = Arg("target_user_info")
|
||||
):
|
||||
async def add_sub_process(state: T_State, user: PlatformTarget = Arg("target_user_info")):
|
||||
try:
|
||||
await config.add_subscribe(
|
||||
user=user,
|
||||
|
@ -1,26 +1,22 @@
|
||||
from typing import Type
|
||||
|
||||
from nonebot.typing import T_State
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.params import Arg, EventPlainText
|
||||
from nonebot.typing import T_State
|
||||
from nonebot_plugin_saa import MessageFactory, PlatformTarget
|
||||
|
||||
from ..config import config
|
||||
from ..platform import platform_manager
|
||||
from ..types import Category
|
||||
from ..utils import parse_text
|
||||
from ..platform import platform_manager
|
||||
from .utils import ensure_user_info, gen_handle_cancel
|
||||
|
||||
|
||||
def do_del_sub(del_sub: Type[Matcher]):
|
||||
def do_del_sub(del_sub: type[Matcher]):
|
||||
handle_cancel = gen_handle_cancel(del_sub, "删除中止")
|
||||
|
||||
del_sub.handle()(ensure_user_info(del_sub))
|
||||
|
||||
@del_sub.handle()
|
||||
async def send_list(
|
||||
state: T_State, user_info: PlatformTarget = Arg("target_user_info")
|
||||
):
|
||||
async def send_list(state: T_State, user_info: PlatformTarget = Arg("target_user_info")):
|
||||
sub_list = await config.list_subscribe(user_info)
|
||||
if not sub_list:
|
||||
await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅")
|
||||
@ -31,22 +27,10 @@ def do_del_sub(del_sub: Type[Matcher]):
|
||||
"platform_name": sub.target.platform_name,
|
||||
"target": sub.target.target,
|
||||
}
|
||||
res += "{} {} {} {}\n".format(
|
||||
index,
|
||||
sub.target.platform_name,
|
||||
sub.target.target_name,
|
||||
sub.target.target,
|
||||
)
|
||||
res += f"{index} {sub.target.platform_name} {sub.target.target_name} {sub.target.target}\n"
|
||||
platform = platform_manager[sub.target.platform_name]
|
||||
if platform.categories:
|
||||
res += " [{}]".format(
|
||||
", ".join(
|
||||
map(
|
||||
lambda x: platform.categories[Category(x)],
|
||||
sub.categories,
|
||||
)
|
||||
)
|
||||
)
|
||||
res += " [{}]".format(", ".join(platform.categories[Category(x)] for x in sub.categories))
|
||||
if platform.enable_tag:
|
||||
res += " {}".format(", ".join(sub.tags))
|
||||
res += "\n"
|
||||
@ -62,7 +46,7 @@ def do_del_sub(del_sub: Type[Matcher]):
|
||||
try:
|
||||
index = int(index_str)
|
||||
await config.del_subscribe(user_info, **state["sub_table"][index])
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
await del_sub.reject("删除错误")
|
||||
else:
|
||||
await del_sub.finish("删除成功")
|
||||
|
@ -1,17 +1,15 @@
|
||||
from typing import Type
|
||||
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.params import Arg
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot_plugin_saa import MessageFactory, PlatformTarget
|
||||
|
||||
from ..config import config
|
||||
from ..platform import platform_manager
|
||||
from ..types import Category
|
||||
from ..utils import parse_text
|
||||
from .utils import ensure_user_info
|
||||
from ..platform import platform_manager
|
||||
|
||||
|
||||
def do_query_sub(query_sub: Type[Matcher]):
|
||||
def do_query_sub(query_sub: type[Matcher]):
|
||||
query_sub.handle()(ensure_user_info(query_sub))
|
||||
|
||||
@query_sub.handle()
|
||||
@ -19,19 +17,10 @@ def do_query_sub(query_sub: Type[Matcher]):
|
||||
sub_list = await config.list_subscribe(user_info)
|
||||
res = "订阅的帐号为:\n"
|
||||
for sub in sub_list:
|
||||
res += "{} {} {}".format(
|
||||
# sub["target_type"], sub["target_name"], sub["target"]
|
||||
sub.target.platform_name,
|
||||
sub.target.target_name,
|
||||
sub.target.target,
|
||||
)
|
||||
res += f"{sub.target.platform_name} {sub.target.target_name} {sub.target.target}"
|
||||
platform = platform_manager[sub.target.platform_name]
|
||||
if platform.categories:
|
||||
res += " [{}]".format(
|
||||
", ".join(
|
||||
map(lambda x: platform.categories[Category(x)], sub.categories)
|
||||
)
|
||||
)
|
||||
res += " [{}]".format(", ".join(platform.categories[Category(x)] for x in sub.categories))
|
||||
if platform.enable_tag:
|
||||
res += " {}".format(", ".join(sub.tags))
|
||||
res += "\n"
|
||||
|
@ -1,13 +1,13 @@
|
||||
import contextlib
|
||||
from typing import Annotated, Type
|
||||
from typing import Annotated
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.params import Depends, EventPlainText, EventToMe
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot.rule import Rule
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.typing import T_State
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.permission import SUPERUSER
|
||||
from nonebot_plugin_saa import extract_target
|
||||
from nonebot.params import Depends, EventToMe, EventPlainText
|
||||
|
||||
from ..platform import platform_manager
|
||||
from ..plugin_config import plugin_config
|
||||
@ -31,7 +31,7 @@ common_platform = [
|
||||
]
|
||||
|
||||
|
||||
def gen_handle_cancel(matcher: Type[Matcher], message: str):
|
||||
def gen_handle_cancel(matcher: type[Matcher], message: str):
|
||||
async def _handle_cancel(text: Annotated[str, EventPlainText()]):
|
||||
if text == "取消":
|
||||
await matcher.finish(message)
|
||||
@ -39,12 +39,10 @@ def gen_handle_cancel(matcher: Type[Matcher], message: str):
|
||||
return Depends(_handle_cancel)
|
||||
|
||||
|
||||
def ensure_user_info(matcher: Type[Matcher]):
|
||||
def ensure_user_info(matcher: type[Matcher]):
|
||||
async def _check_user_info(state: T_State):
|
||||
if not state.get("target_user_info"):
|
||||
await matcher.finish(
|
||||
"No target_user_info set, this shouldn't happen, please issue"
|
||||
)
|
||||
await matcher.finish("No target_user_info set, this shouldn't happen, please issue")
|
||||
|
||||
return _check_user_info
|
||||
|
||||
|
@ -1,17 +1,16 @@
|
||||
import difflib
|
||||
import re
|
||||
import sys
|
||||
from typing import Union
|
||||
|
||||
import nonebot
|
||||
from bs4 import BeautifulSoup as bs
|
||||
from nonebot.log import default_format, logger
|
||||
from nonebot.plugin import require
|
||||
from nonebot_plugin_saa import Image, MessageSegmentFactory, Text
|
||||
from bs4 import BeautifulSoup as bs
|
||||
from nonebot.log import logger, default_format
|
||||
from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
|
||||
|
||||
from ..plugin_config import plugin_config
|
||||
from .context import ProcessContext
|
||||
from .http import http_client
|
||||
from .context import ProcessContext
|
||||
from ..plugin_config import plugin_config
|
||||
from .scheduler_config import SchedulerConfig, scheduler
|
||||
|
||||
__all__ = [
|
||||
@ -30,7 +29,7 @@ class Singleton(type):
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
|
||||
cls._instances[cls] = super().__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
@ -63,7 +62,7 @@ def html_to_text(html: str, query_dict: dict = {}) -> str:
|
||||
|
||||
class Filter:
|
||||
def __init__(self) -> None:
|
||||
self.level: Union[int, str] = "DEBUG"
|
||||
self.level: int | str = "DEBUG"
|
||||
|
||||
def __call__(self, record):
|
||||
module_name: str = record["name"]
|
||||
@ -71,9 +70,7 @@ class Filter:
|
||||
if module:
|
||||
module_name = getattr(module, "__module_name__", module_name)
|
||||
record["name"] = module_name.split(".")[0]
|
||||
levelno = (
|
||||
logger.level(self.level).no if isinstance(self.level, str) else self.level
|
||||
)
|
||||
levelno = logger.level(self.level).no if isinstance(self.level, str) else self.level
|
||||
nonebot_warning_level = logger.level("WARNING").no
|
||||
return (
|
||||
record["level"].no >= levelno
|
||||
@ -94,11 +91,7 @@ if plugin_config.bison_filter_log:
|
||||
)
|
||||
config = nonebot.get_driver().config
|
||||
logger.success("Muted info & success from nonebot")
|
||||
default_filter.level = (
|
||||
("DEBUG" if config.debug else "INFO")
|
||||
if config.log_level is None
|
||||
else config.log_level
|
||||
)
|
||||
default_filter.level = ("DEBUG" if config.debug else "INFO") if config.log_level is None else config.log_level
|
||||
|
||||
|
||||
def text_similarity(str1, str2) -> float:
|
||||
|
@ -1,6 +1,6 @@
|
||||
from base64 import b64encode
|
||||
|
||||
from httpx import AsyncClient, Response
|
||||
from httpx import Response, AsyncClient
|
||||
|
||||
|
||||
class ProcessContext:
|
||||
@ -33,8 +33,13 @@ class ProcessContext:
|
||||
res = []
|
||||
for req in self.reqs:
|
||||
if self._should_print_content(req):
|
||||
log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} {req.text}"
|
||||
log_content = (
|
||||
f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} {req.text}"
|
||||
)
|
||||
else:
|
||||
log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} b64encoded: {b64encode(req.content[:50]).decode()}"
|
||||
log_content = (
|
||||
f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} "
|
||||
f"b64encoded: {b64encode(req.content[:50]).decode()}"
|
||||
)
|
||||
res.append(log_content)
|
||||
return res
|
||||
|
@ -1,5 +1,3 @@
|
||||
import functools
|
||||
|
||||
import httpx
|
||||
|
||||
from ..plugin_config import plugin_config
|
||||
@ -10,7 +8,6 @@ http_args = {
|
||||
http_headers = {"user-agent": plugin_config.bison_ua}
|
||||
|
||||
|
||||
@functools.wraps(httpx.AsyncClient)
|
||||
def http_client(*args, **kwargs):
|
||||
if headers := kwargs.get("headers"):
|
||||
new_headers = http_headers.copy()
|
||||
@ -18,7 +15,4 @@ def http_client(*args, **kwargs):
|
||||
kwargs["headers"] = new_headers
|
||||
else:
|
||||
kwargs["headers"] = http_headers
|
||||
return httpx.AsyncClient(*args, **kwargs)
|
||||
|
||||
|
||||
http_client = functools.partial(http_client, **http_args)
|
||||
return httpx.AsyncClient(*args, **kwargs, **http_args)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Literal, Type
|
||||
from typing import Literal
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
@ -7,7 +7,6 @@ from .http import http_client
|
||||
|
||||
|
||||
class SchedulerConfig:
|
||||
|
||||
schedule_type: Literal["date", "interval", "cron"]
|
||||
schedule_setting: dict
|
||||
name: str
|
||||
@ -25,9 +24,7 @@ class SchedulerConfig:
|
||||
return self.default_http_client
|
||||
|
||||
|
||||
def scheduler(
|
||||
schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict
|
||||
) -> Type[SchedulerConfig]:
|
||||
def scheduler(schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict) -> type[SchedulerConfig]:
|
||||
return type(
|
||||
"AnonymousScheduleConfig",
|
||||
(SchedulerConfig,),
|
||||
|
@ -85,7 +85,7 @@ line-length = 120
|
||||
target-version = "py310"
|
||||
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
line-length = 118
|
||||
target-version = ["py310", "py311"]
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
|
Loading…
x
Reference in New Issue
Block a user