mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-08 04:43:00 +08:00
🎨 按ruff的检查调整程序代码
This commit is contained in:
parent
f232ce4c3e
commit
dba8f2a9cb
@ -6,25 +6,18 @@ require("nonebot_plugin_saa")
|
|||||||
|
|
||||||
import nonebot_plugin_saa
|
import nonebot_plugin_saa
|
||||||
|
|
||||||
from . import (
|
|
||||||
admin_page,
|
|
||||||
bootstrap,
|
|
||||||
config,
|
|
||||||
platform,
|
|
||||||
post,
|
|
||||||
scheduler,
|
|
||||||
send,
|
|
||||||
sub_manager,
|
|
||||||
types,
|
|
||||||
utils,
|
|
||||||
)
|
|
||||||
from .plugin_config import PlugConfig, plugin_config
|
from .plugin_config import PlugConfig, plugin_config
|
||||||
|
from . import post, send, types, utils, config, platform, bootstrap, scheduler, admin_page, sub_manager
|
||||||
|
|
||||||
__help__version__ = "0.7.3"
|
__help__version__ = "0.7.3"
|
||||||
nonebot_plugin_saa.enable_auto_select_bot()
|
nonebot_plugin_saa.enable_auto_select_bot()
|
||||||
|
|
||||||
__help__plugin__name__ = "nonebot_bison"
|
__help__plugin__name__ = "nonebot_bison"
|
||||||
__usage__ = f"本bot可以提供b站、微博等社交媒体的消息订阅,详情请查看本bot文档,或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,发送“查询订阅”或“删除订阅”管理订阅"
|
__usage__ = (
|
||||||
|
"本bot可以提供b站、微博等社交媒体的消息订阅,详情请查看本bot文档,"
|
||||||
|
f"或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,"
|
||||||
|
f"发送“查询订阅”或“删除订阅”管理订阅"
|
||||||
|
)
|
||||||
|
|
||||||
__supported_adapters__ = nonebot_plugin_saa.__plugin_meta__.supported_adapters
|
__supported_adapters__ = nonebot_plugin_saa.__plugin_meta__.supported_adapters
|
||||||
|
|
||||||
@ -41,6 +34,7 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"admin_page",
|
"admin_page",
|
||||||
|
"bootstrap",
|
||||||
"config",
|
"config",
|
||||||
"sub_manager",
|
"sub_manager",
|
||||||
"post",
|
"post",
|
||||||
|
@ -1,17 +1,16 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from nonebot import get_driver, on_command
|
|
||||||
from nonebot.adapters.onebot.v11 import Bot
|
|
||||||
from nonebot.adapters.onebot.v11.event import PrivateMessageEvent
|
|
||||||
from nonebot.drivers.fastapi import Driver
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.rule import to_me
|
from nonebot.rule import to_me
|
||||||
from nonebot.typing import T_State
|
from nonebot.typing import T_State
|
||||||
|
from nonebot import get_driver, on_command
|
||||||
|
from nonebot.drivers.fastapi import Driver
|
||||||
|
from nonebot.adapters.onebot.v11 import Bot
|
||||||
|
from nonebot.adapters.onebot.v11.event import PrivateMessageEvent
|
||||||
|
|
||||||
from ..plugin_config import plugin_config
|
|
||||||
from .api import router as api_router
|
from .api import router as api_router
|
||||||
|
from ..plugin_config import plugin_config
|
||||||
from .token_manager import token_manager as tm
|
from .token_manager import token_manager as tm
|
||||||
|
|
||||||
STATIC_PATH = (Path(__file__).parent / "dist").resolve()
|
STATIC_PATH = (Path(__file__).parent / "dist").resolve()
|
||||||
@ -28,11 +27,9 @@ def init_fastapi():
|
|||||||
class SinglePageApplication(StaticFiles):
|
class SinglePageApplication(StaticFiles):
|
||||||
def __init__(self, directory: os.PathLike, index="index.html"):
|
def __init__(self, directory: os.PathLike, index="index.html"):
|
||||||
self.index = index
|
self.index = index
|
||||||
super().__init__(
|
super().__init__(directory=directory, packages=None, html=True, check_dir=True)
|
||||||
directory=directory, packages=None, html=True, check_dir=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def lookup_path(self, path: str) -> tuple[str, Union[os.stat_result, None]]:
|
def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
|
||||||
full_path, stat_res = super().lookup_path(path)
|
full_path, stat_res = super().lookup_path(path)
|
||||||
if stat_res is None:
|
if stat_res is None:
|
||||||
return super().lookup_path(self.index)
|
return super().lookup_path(self.index)
|
||||||
@ -45,9 +42,7 @@ def init_fastapi():
|
|||||||
description="nonebot-bison webui and api",
|
description="nonebot-bison webui and api",
|
||||||
)
|
)
|
||||||
nonebot_app.include_router(api_router)
|
nonebot_app.include_router(api_router)
|
||||||
nonebot_app.mount(
|
nonebot_app.mount("/", SinglePageApplication(directory=static_path), name="bison-frontend")
|
||||||
"/", SinglePageApplication(directory=static_path), name="bison-frontend"
|
|
||||||
)
|
|
||||||
|
|
||||||
app = driver.server_app
|
app = driver.server_app
|
||||||
app.mount("/bison", nonebot_app, "nonebot-bison")
|
app.mount("/bison", nonebot_app, "nonebot-bison")
|
||||||
@ -63,10 +58,9 @@ def init_fastapi():
|
|||||||
if host in ["0.0.0.0", "127.0.0.1"]:
|
if host in ["0.0.0.0", "127.0.0.1"]:
|
||||||
host = "localhost"
|
host = "localhost"
|
||||||
logger.opt(colors=True).info(
|
logger.opt(colors=True).info(
|
||||||
f"Nonebot Bison frontend will be running at: "
|
f"Nonebot Bison frontend will be running at: " f"<b><u>http://{host}:{port}/bison</u></b>"
|
||||||
f"<b><u>http://{host}:{port}/bison</u></b>"
|
|
||||||
)
|
)
|
||||||
logger.opt(colors=True).info(f"该页面不能被直接访问,请私聊bot <b><u>后台管理</u></b> 以获取可访问地址")
|
logger.opt(colors=True).info("该页面不能被直接访问,请私聊bot <b><u>后台管理</u></b> 以获取可访问地址")
|
||||||
|
|
||||||
|
|
||||||
def register_get_token_handler():
|
def register_get_token_handler():
|
||||||
@ -93,6 +87,4 @@ if (STATIC_PATH / "index.html").exists():
|
|||||||
else:
|
else:
|
||||||
logger.warning("your driver is not fastapi, webui feature will be disabled")
|
logger.warning("your driver is not fastapi, webui feature will be disabled")
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning("Frontend file not found, please compile it or use docker or pypi version")
|
||||||
"Frontend file not found, please compile it or use docker or pypi version"
|
|
||||||
)
|
|
||||||
|
@ -1,35 +1,30 @@
|
|||||||
import nonebot
|
import nonebot
|
||||||
from fastapi import status
|
from fastapi import status
|
||||||
from fastapi.exceptions import HTTPException
|
|
||||||
from fastapi.param_functions import Depends
|
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from fastapi.security.oauth2 import OAuth2PasswordBearer
|
from fastapi.param_functions import Depends
|
||||||
|
from fastapi.exceptions import HTTPException
|
||||||
from nonebot_plugin_saa import TargetQQGroup
|
from nonebot_plugin_saa import TargetQQGroup
|
||||||
|
from fastapi.security.oauth2 import OAuth2PasswordBearer
|
||||||
from nonebot_plugin_saa.utils.auto_select_bot import get_bot
|
from nonebot_plugin_saa.utils.auto_select_bot import get_bot
|
||||||
|
|
||||||
from ..apis import check_sub_target
|
|
||||||
from ..config import (
|
|
||||||
NoSuchSubscribeException,
|
|
||||||
NoSuchTargetException,
|
|
||||||
NoSuchUserException,
|
|
||||||
config,
|
|
||||||
)
|
|
||||||
from ..config.db_config import SubscribeDupException
|
|
||||||
from ..platform import platform_manager
|
|
||||||
from ..types import Target as T_Target
|
|
||||||
from ..types import WeightConfig
|
from ..types import WeightConfig
|
||||||
from ..utils.get_bot import get_groups
|
from ..apis import check_sub_target
|
||||||
from .jwt import load_jwt, pack_jwt
|
from .jwt import load_jwt, pack_jwt
|
||||||
|
from ..types import Target as T_Target
|
||||||
|
from ..utils.get_bot import get_groups
|
||||||
|
from ..platform import platform_manager
|
||||||
from .token_manager import token_manager
|
from .token_manager import token_manager
|
||||||
|
from ..config.db_config import SubscribeDupException
|
||||||
|
from ..config import NoSuchUserException, NoSuchTargetException, NoSuchSubscribeException, config
|
||||||
from .types import (
|
from .types import (
|
||||||
AddSubscribeReq,
|
TokenResp,
|
||||||
GlobalConf,
|
GlobalConf,
|
||||||
PlatformConfig,
|
|
||||||
StatusResp,
|
StatusResp,
|
||||||
|
SubscribeResp,
|
||||||
|
PlatformConfig,
|
||||||
|
AddSubscribeReq,
|
||||||
SubscribeConfig,
|
SubscribeConfig,
|
||||||
SubscribeGroupDetail,
|
SubscribeGroupDetail,
|
||||||
SubscribeResp,
|
|
||||||
TokenResp,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["api"])
|
router = APIRouter(prefix="/api", tags=["api"])
|
||||||
@ -44,9 +39,7 @@ async def get_jwt_obj(token: str = Depends(oath_scheme)):
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
async def check_group_permission(
|
async def check_group_permission(groupNumber: int, token_obj: dict = Depends(get_jwt_obj)):
|
||||||
groupNumber: int, token_obj: dict = Depends(get_jwt_obj)
|
|
||||||
):
|
|
||||||
groups = token_obj["groups"]
|
groups = token_obj["groups"]
|
||||||
for group in groups:
|
for group in groups:
|
||||||
if int(groupNumber) == group["id"]:
|
if int(groupNumber) == group["id"]:
|
||||||
@ -95,15 +88,13 @@ async def auth(token: str) -> TokenResp:
|
|||||||
jwt_obj = {
|
jwt_obj = {
|
||||||
"id": qq,
|
"id": qq,
|
||||||
"type": "admin",
|
"type": "admin",
|
||||||
"groups": list(
|
"groups": [
|
||||||
map(
|
{
|
||||||
lambda info: {
|
|
||||||
"id": info["group_id"],
|
"id": info["group_id"],
|
||||||
"name": info["group_name"],
|
"name": info["group_name"],
|
||||||
},
|
}
|
||||||
await get_groups(),
|
for info in await get_groups()
|
||||||
)
|
],
|
||||||
),
|
|
||||||
}
|
}
|
||||||
ret_obj = TokenResp(
|
ret_obj = TokenResp(
|
||||||
type="admin",
|
type="admin",
|
||||||
@ -134,18 +125,16 @@ async def get_subs_info(jwt_obj: dict = Depends(get_jwt_obj)) -> SubscribeResp:
|
|||||||
for group in groups:
|
for group in groups:
|
||||||
group_id = group["id"]
|
group_id = group["id"]
|
||||||
raw_subs = await config.list_subscribe(TargetQQGroup(group_id=group_id))
|
raw_subs = await config.list_subscribe(TargetQQGroup(group_id=group_id))
|
||||||
subs = list(
|
subs = [
|
||||||
map(
|
SubscribeConfig(
|
||||||
lambda sub: SubscribeConfig(
|
|
||||||
platformName=sub.target.platform_name,
|
platformName=sub.target.platform_name,
|
||||||
targetName=sub.target.target_name,
|
targetName=sub.target.target_name,
|
||||||
cats=sub.categories,
|
cats=sub.categories,
|
||||||
tags=sub.tags,
|
tags=sub.tags,
|
||||||
target=sub.target.target,
|
target=sub.target.target,
|
||||||
),
|
|
||||||
raw_subs,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
for sub in raw_subs
|
||||||
|
]
|
||||||
res[group_id] = SubscribeGroupDetail(name=group["name"], subscribes=subs)
|
res[group_id] = SubscribeGroupDetail(name=group["name"], subscribes=subs)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@ -174,9 +163,7 @@ async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp:
|
|||||||
@router.delete("/subs", dependencies=[Depends(check_group_permission)])
|
@router.delete("/subs", dependencies=[Depends(check_group_permission)])
|
||||||
async def del_group_sub(groupNumber: int, platformName: str, target: str):
|
async def del_group_sub(groupNumber: int, platformName: str, target: str):
|
||||||
try:
|
try:
|
||||||
await config.del_subscribe(
|
await config.del_subscribe(TargetQQGroup(group_id=groupNumber), target, platformName)
|
||||||
TargetQQGroup(group_id=groupNumber), target, platformName
|
|
||||||
)
|
|
||||||
except (NoSuchUserException, NoSuchSubscribeException):
|
except (NoSuchUserException, NoSuchSubscribeException):
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe")
|
||||||
return StatusResp(ok=True, msg="")
|
return StatusResp(ok=True, msg="")
|
||||||
@ -204,13 +191,9 @@ async def get_weight_config():
|
|||||||
|
|
||||||
|
|
||||||
@router.put("/weight", dependencies=[Depends(check_is_superuser)])
|
@router.put("/weight", dependencies=[Depends(check_is_superuser)])
|
||||||
async def update_weigth_config(
|
async def update_weigth_config(platformName: str, target: str, weight_config: WeightConfig):
|
||||||
platformName: str, target: str, weight_config: WeightConfig
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
await config.update_time_weight_config(
|
await config.update_time_weight_config(T_Target(target), platformName, weight_config)
|
||||||
T_Target(target), platformName, weight_config
|
|
||||||
)
|
|
||||||
except NoSuchTargetException:
|
except NoSuchTargetException:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such subscribe")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such subscribe")
|
||||||
return StatusResp(ok=True, msg="")
|
return StatusResp(ok=True, msg="")
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import datetime
|
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
from typing import Optional
|
import datetime
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
@ -16,8 +15,8 @@ def pack_jwt(obj: dict) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_jwt(token: str) -> Optional[dict]:
|
def load_jwt(token: str) -> dict | None:
|
||||||
try:
|
try:
|
||||||
return jwt.decode(token, _key, algorithms=["HS256"])
|
return jwt.decode(token, _key, algorithms=["HS256"])
|
||||||
except:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from expiringdict import ExpiringDict
|
from expiringdict import ExpiringDict
|
||||||
|
|
||||||
@ -9,7 +8,7 @@ class TokenManager:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60 * 10)
|
self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60 * 10)
|
||||||
|
|
||||||
def get_user(self, token: str) -> Optional[tuple]:
|
def get_user(self, token: str) -> tuple | None:
|
||||||
res = self.token_manager.get(token)
|
res = self.token_manager.get(token)
|
||||||
assert res is None or isinstance(res, tuple)
|
assert res is None or isinstance(res, tuple)
|
||||||
return res
|
return res
|
||||||
|
@ -1,2 +1,4 @@
|
|||||||
from .db_config import config
|
from .db_config import config as config
|
||||||
from .utils import NoSuchSubscribeException, NoSuchTargetException, NoSuchUserException
|
from .utils import NoSuchUserException as NoSuchUserException
|
||||||
|
from .utils import NoSuchTargetException as NoSuchTargetException
|
||||||
|
from .utils import NoSuchSubscribeException as NoSuchSubscribeException
|
||||||
|
@ -1,20 +1,19 @@
|
|||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
import json
|
||||||
from datetime import datetime
|
|
||||||
from os import path
|
from os import path
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import DefaultDict, Literal, Mapping, TypedDict
|
from datetime import datetime
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Literal, TypedDict
|
||||||
|
|
||||||
import nonebot
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from tinydb import Query, TinyDB
|
from tinydb import Query, TinyDB
|
||||||
|
|
||||||
|
from ..utils import Singleton
|
||||||
|
from ..types import User, Target
|
||||||
from ..platform import platform_manager
|
from ..platform import platform_manager
|
||||||
from ..plugin_config import plugin_config
|
from ..plugin_config import plugin_config
|
||||||
from ..types import Target, User
|
from .utils import NoSuchUserException, NoSuchSubscribeException
|
||||||
from ..utils import Singleton
|
|
||||||
from .utils import NoSuchSubscribeException, NoSuchUserException
|
|
||||||
|
|
||||||
supported_target_type = platform_manager.keys()
|
supported_target_type = platform_manager.keys()
|
||||||
|
|
||||||
@ -89,17 +88,16 @@ class Config(metaclass=Singleton):
|
|||||||
self.target_user_cat_cache = {}
|
self.target_user_cat_cache = {}
|
||||||
self.target_user_tag_cache = {}
|
self.target_user_tag_cache = {}
|
||||||
self.target_list = {}
|
self.target_list = {}
|
||||||
self.next_index: DefaultDict[str, int] = defaultdict(lambda: 0)
|
self.next_index: defaultdict[str, int] = defaultdict(lambda: 0)
|
||||||
else:
|
else:
|
||||||
self.available = False
|
self.available = False
|
||||||
|
|
||||||
def add_subscribe(
|
def add_subscribe(self, user, user_type, target, target_name, target_type, cats, tags):
|
||||||
self, user, user_type, target, target_name, target_type, cats, tags
|
|
||||||
):
|
|
||||||
user_query = Query()
|
user_query = Query()
|
||||||
query = (user_query.user == user) & (user_query.user_type == user_type)
|
query = (user_query.user == user) & (user_query.user_type == user_type)
|
||||||
if user_data := self.user_target.get(query):
|
if user_data := self.user_target.get(query):
|
||||||
# update
|
# update
|
||||||
|
assert not isinstance(user_data, list)
|
||||||
subs: list = user_data.get("subs", [])
|
subs: list = user_data.get("subs", [])
|
||||||
subs.append(
|
subs.append(
|
||||||
{
|
{
|
||||||
@ -132,9 +130,8 @@ class Config(metaclass=Singleton):
|
|||||||
|
|
||||||
def list_subscribe(self, user, user_type) -> list[SubscribeContent]:
|
def list_subscribe(self, user, user_type) -> list[SubscribeContent]:
|
||||||
query = Query()
|
query = Query()
|
||||||
if user_sub := self.user_target.get(
|
if user_sub := self.user_target.get((query.user == user) & (query.user_type == user_type)):
|
||||||
(query.user == user) & (query.user_type == user_type)
|
assert not isinstance(user_sub, list)
|
||||||
):
|
|
||||||
return user_sub["subs"]
|
return user_sub["subs"]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@ -146,6 +143,7 @@ class Config(metaclass=Singleton):
|
|||||||
query = (user_query.user == user) & (user_query.user_type == user_type)
|
query = (user_query.user == user) & (user_query.user_type == user_type)
|
||||||
if not (query_res := self.user_target.get(query)):
|
if not (query_res := self.user_target.get(query)):
|
||||||
raise NoSuchUserException()
|
raise NoSuchUserException()
|
||||||
|
assert not isinstance(query_res, list)
|
||||||
subs = query_res.get("subs", [])
|
subs = query_res.get("subs", [])
|
||||||
for idx, sub in enumerate(subs):
|
for idx, sub in enumerate(subs):
|
||||||
if sub.get("target") == target and sub.get("target_type") == target_type:
|
if sub.get("target") == target and sub.get("target_type") == target_type:
|
||||||
@ -155,13 +153,12 @@ class Config(metaclass=Singleton):
|
|||||||
return
|
return
|
||||||
raise NoSuchSubscribeException()
|
raise NoSuchSubscribeException()
|
||||||
|
|
||||||
def update_subscribe(
|
def update_subscribe(self, user, user_type, target, target_name, target_type, cats, tags):
|
||||||
self, user, user_type, target, target_name, target_type, cats, tags
|
|
||||||
):
|
|
||||||
user_query = Query()
|
user_query = Query()
|
||||||
query = (user_query.user == user) & (user_query.user_type == user_type)
|
query = (user_query.user == user) & (user_query.user_type == user_type)
|
||||||
if user_data := self.user_target.get(query):
|
if user_data := self.user_target.get(query):
|
||||||
# update
|
# update
|
||||||
|
assert not isinstance(user_data, list)
|
||||||
subs: list = user_data.get("subs", [])
|
subs: list = user_data.get("subs", [])
|
||||||
find_flag = False
|
find_flag = False
|
||||||
for item in subs:
|
for item in subs:
|
||||||
@ -182,19 +179,13 @@ class Config(metaclass=Singleton):
|
|||||||
|
|
||||||
def update_send_cache(self):
|
def update_send_cache(self):
|
||||||
res = {target_type: defaultdict(list) for target_type in supported_target_type}
|
res = {target_type: defaultdict(list) for target_type in supported_target_type}
|
||||||
cat_res = {
|
cat_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
|
||||||
target_type: defaultdict(lambda: defaultdict(list))
|
tag_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
|
||||||
for target_type in supported_target_type
|
|
||||||
}
|
|
||||||
tag_res = {
|
|
||||||
target_type: defaultdict(lambda: defaultdict(list))
|
|
||||||
for target_type in supported_target_type
|
|
||||||
}
|
|
||||||
# res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
|
# res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
|
||||||
to_del = []
|
to_del = []
|
||||||
for user in self.user_target.all():
|
for user in self.user_target.all():
|
||||||
for sub in user.get("subs", []):
|
for sub in user.get("subs", []):
|
||||||
if not sub.get("target_type") in supported_target_type:
|
if sub.get("target_type") not in supported_target_type:
|
||||||
to_del.append(
|
to_del.append(
|
||||||
{
|
{
|
||||||
"user": user["user"],
|
"user": user["user"],
|
||||||
@ -204,36 +195,28 @@ class Config(metaclass=Singleton):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
res[sub["target_type"]][sub["target"]].append(
|
res[sub["target_type"]][sub["target"]].append(User(user["user"], user["user_type"]))
|
||||||
User(user["user"], user["user_type"])
|
cat_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[
|
||||||
)
|
"cats"
|
||||||
cat_res[sub["target_type"]][sub["target"]][
|
]
|
||||||
"{}-{}".format(user["user_type"], user["user"])
|
tag_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[
|
||||||
] = sub["cats"]
|
"tags"
|
||||||
tag_res[sub["target_type"]][sub["target"]][
|
]
|
||||||
"{}-{}".format(user["user_type"], user["user"])
|
|
||||||
] = sub["tags"]
|
|
||||||
self.target_user_cache = res
|
self.target_user_cache = res
|
||||||
self.target_user_cat_cache = cat_res
|
self.target_user_cat_cache = cat_res
|
||||||
self.target_user_tag_cache = tag_res
|
self.target_user_tag_cache = tag_res
|
||||||
for target_type in self.target_user_cache:
|
for target_type in self.target_user_cache:
|
||||||
self.target_list[target_type] = list(
|
self.target_list[target_type] = list(self.target_user_cache[target_type].keys())
|
||||||
self.target_user_cache[target_type].keys()
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Deleting {to_del}")
|
logger.info(f"Deleting {to_del}")
|
||||||
for d in to_del:
|
for d in to_del:
|
||||||
self.del_subscribe(**d)
|
self.del_subscribe(**d)
|
||||||
|
|
||||||
def get_sub_category(self, target_type, target, user_type, user):
|
def get_sub_category(self, target_type, target, user_type, user):
|
||||||
return self.target_user_cat_cache[target_type][target][
|
return self.target_user_cat_cache[target_type][target][f"{user_type}-{user}"]
|
||||||
"{}-{}".format(user_type, user)
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_sub_tags(self, target_type, target, user_type, user):
|
def get_sub_tags(self, target_type, target, user_type, user):
|
||||||
return self.target_user_tag_cache[target_type][target][
|
return self.target_user_tag_cache[target_type][target][f"{user_type}-{user}"]
|
||||||
"{}-{}".format(user_type, user)
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_next_target(self, target_type):
|
def get_next_target(self, target_type):
|
||||||
# FIXME 插入或删除target后对队列的影响(但是并不是大问题
|
# FIXME 插入或删除target后对队列的影响(但是并不是大问题
|
||||||
|
@ -1,19 +1,19 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, time
|
from datetime import time, datetime
|
||||||
from typing import Awaitable, Callable, Optional, Sequence
|
from collections.abc import Callable, Sequence, Awaitable
|
||||||
|
|
||||||
from nonebot_plugin_datastore import create_session
|
|
||||||
from nonebot_plugin_saa import PlatformTarget
|
|
||||||
from sqlalchemy import delete, func, select
|
|
||||||
from sqlalchemy.exc import IntegrityError
|
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy import func, delete, select
|
||||||
|
from nonebot_plugin_saa import PlatformTarget
|
||||||
|
from nonebot_plugin_datastore import create_session
|
||||||
|
|
||||||
from ..types import Category, PlatformWeightConfigResp, Tag
|
from ..types import Tag
|
||||||
from ..types import Target as T_Target
|
from ..types import Target as T_Target
|
||||||
from ..types import TimeWeightConfig, UserSubInfo, WeightConfig
|
|
||||||
from .db_model import ScheduleTimeWeight, Subscribe, Target, User
|
|
||||||
from .utils import NoSuchTargetException
|
from .utils import NoSuchTargetException
|
||||||
|
from .db_model import User, Target, Subscribe, ScheduleTimeWeight
|
||||||
|
from ..types import Category, UserSubInfo, WeightConfig, TimeWeightConfig, PlatformWeightConfigResp
|
||||||
|
|
||||||
|
|
||||||
def _get_time():
|
def _get_time():
|
||||||
@ -48,23 +48,17 @@ class DBConfig:
|
|||||||
):
|
):
|
||||||
async with create_session() as session:
|
async with create_session() as session:
|
||||||
db_user_stmt = select(User).where(User.user_target == user.dict())
|
db_user_stmt = select(User).where(User.user_target == user.dict())
|
||||||
db_user: Optional[User] = await session.scalar(db_user_stmt)
|
db_user: User | None = await session.scalar(db_user_stmt)
|
||||||
if not db_user:
|
if not db_user:
|
||||||
db_user = User(user_target=user.dict())
|
db_user = User(user_target=user.dict())
|
||||||
session.add(db_user)
|
session.add(db_user)
|
||||||
db_target_stmt = (
|
db_target_stmt = (
|
||||||
select(Target)
|
select(Target).where(Target.platform_name == platform_name).where(Target.target == target)
|
||||||
.where(Target.platform_name == platform_name)
|
|
||||||
.where(Target.target == target)
|
|
||||||
)
|
)
|
||||||
db_target: Optional[Target] = await session.scalar(db_target_stmt)
|
db_target: Target | None = await session.scalar(db_target_stmt)
|
||||||
if not db_target:
|
if not db_target:
|
||||||
db_target = Target(
|
db_target = Target(target=target, platform_name=platform_name, target_name=target_name)
|
||||||
target=target, platform_name=platform_name, target_name=target_name
|
await asyncio.gather(*[hook(platform_name, target) for hook in self.add_target_hook])
|
||||||
)
|
|
||||||
await asyncio.gather(
|
|
||||||
*[hook(platform_name, target) for hook in self.add_target_hook]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
db_target.target_name = target_name
|
db_target.target_name = target_name
|
||||||
subscribe = Subscribe(
|
subscribe = Subscribe(
|
||||||
@ -96,44 +90,25 @@ class DBConfig:
|
|||||||
"""获取数据库中带有user、target信息的subscribe数据"""
|
"""获取数据库中带有user、target信息的subscribe数据"""
|
||||||
async with create_session() as session:
|
async with create_session() as session:
|
||||||
query_stmt = (
|
query_stmt = (
|
||||||
select(Subscribe)
|
select(Subscribe).join(User).options(selectinload(Subscribe.target), selectinload(Subscribe.user))
|
||||||
.join(User)
|
|
||||||
.options(selectinload(Subscribe.target), selectinload(Subscribe.user))
|
|
||||||
)
|
)
|
||||||
subs = (await session.scalars(query_stmt)).all()
|
subs = (await session.scalars(query_stmt)).all()
|
||||||
|
|
||||||
return subs
|
return subs
|
||||||
|
|
||||||
async def del_subscribe(
|
async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str):
|
||||||
self, user: PlatformTarget, target: str, platform_name: str
|
|
||||||
):
|
|
||||||
async with create_session() as session:
|
async with create_session() as session:
|
||||||
user_obj = await session.scalar(
|
user_obj = await session.scalar(select(User).where(User.user_target == user.dict()))
|
||||||
select(User).where(User.user_target == user.dict())
|
|
||||||
)
|
|
||||||
target_obj = await session.scalar(
|
target_obj = await session.scalar(
|
||||||
select(Target).where(
|
select(Target).where(Target.platform_name == platform_name, Target.target == target)
|
||||||
Target.platform_name == platform_name, Target.target == target
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await session.execute(
|
|
||||||
delete(Subscribe).where(
|
|
||||||
Subscribe.user == user_obj, Subscribe.target == target_obj
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
await session.execute(delete(Subscribe).where(Subscribe.user == user_obj, Subscribe.target == target_obj))
|
||||||
target_count = await session.scalar(
|
target_count = await session.scalar(
|
||||||
select(func.count())
|
select(func.count()).select_from(Subscribe).where(Subscribe.target == target_obj)
|
||||||
.select_from(Subscribe)
|
|
||||||
.where(Subscribe.target == target_obj)
|
|
||||||
)
|
)
|
||||||
if target_count == 0:
|
if target_count == 0:
|
||||||
# delete empty target
|
# delete empty target
|
||||||
await asyncio.gather(
|
await asyncio.gather(*[hook(platform_name, T_Target(target)) for hook in self.delete_target_hook])
|
||||||
*[
|
|
||||||
hook(platform_name, T_Target(target))
|
|
||||||
for hook in self.delete_target_hook
|
|
||||||
]
|
|
||||||
)
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def update_subscribe(
|
async def update_subscribe(
|
||||||
@ -165,29 +140,22 @@ class DBConfig:
|
|||||||
async def get_platform_target(self, platform_name: str) -> Sequence[Target]:
|
async def get_platform_target(self, platform_name: str) -> Sequence[Target]:
|
||||||
async with create_session() as sess:
|
async with create_session() as sess:
|
||||||
subq = select(Subscribe.target_id).distinct().subquery()
|
subq = select(Subscribe.target_id).distinct().subquery()
|
||||||
query = (
|
query = select(Target).join(subq).where(Target.platform_name == platform_name)
|
||||||
select(Target).join(subq).where(Target.platform_name == platform_name)
|
|
||||||
)
|
|
||||||
return (await sess.scalars(query)).all()
|
return (await sess.scalars(query)).all()
|
||||||
|
|
||||||
async def get_time_weight_config(
|
async def get_time_weight_config(self, target: T_Target, platform_name: str) -> WeightConfig:
|
||||||
self, target: T_Target, platform_name: str
|
|
||||||
) -> WeightConfig:
|
|
||||||
async with create_session() as sess:
|
async with create_session() as sess:
|
||||||
time_weight_conf = (
|
time_weight_conf = (
|
||||||
await sess.scalars(
|
await sess.scalars(
|
||||||
select(ScheduleTimeWeight)
|
select(ScheduleTimeWeight)
|
||||||
.where(
|
.where(Target.platform_name == platform_name, Target.target == target)
|
||||||
Target.platform_name == platform_name, Target.target == target
|
|
||||||
)
|
|
||||||
.join(Target)
|
.join(Target)
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
targetObj = await sess.scalar(
|
targetObj = await sess.scalar(
|
||||||
select(Target).where(
|
select(Target).where(Target.platform_name == platform_name, Target.target == target)
|
||||||
Target.platform_name == platform_name, Target.target == target
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
assert targetObj
|
||||||
return WeightConfig(
|
return WeightConfig(
|
||||||
default=targetObj.default_schedule_weight,
|
default=targetObj.default_schedule_weight,
|
||||||
time_config=[
|
time_config=[
|
||||||
@ -200,22 +168,16 @@ class DBConfig:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def update_time_weight_config(
|
async def update_time_weight_config(self, target: T_Target, platform_name: str, conf: WeightConfig):
|
||||||
self, target: T_Target, platform_name: str, conf: WeightConfig
|
|
||||||
):
|
|
||||||
async with create_session() as sess:
|
async with create_session() as sess:
|
||||||
targetObj = await sess.scalar(
|
targetObj = await sess.scalar(
|
||||||
select(Target).where(
|
select(Target).where(Target.platform_name == platform_name, Target.target == target)
|
||||||
Target.platform_name == platform_name, Target.target == target
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if not targetObj:
|
if not targetObj:
|
||||||
raise NoSuchTargetException()
|
raise NoSuchTargetException()
|
||||||
target_id = targetObj.id
|
target_id = targetObj.id
|
||||||
targetObj.default_schedule_weight = conf.default
|
targetObj.default_schedule_weight = conf.default
|
||||||
delete_statement = delete(ScheduleTimeWeight).where(
|
delete_statement = delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id)
|
||||||
ScheduleTimeWeight.target_id == target_id
|
|
||||||
)
|
|
||||||
await sess.execute(delete_statement)
|
await sess.execute(delete_statement)
|
||||||
for time_conf in conf.time_config:
|
for time_conf in conf.time_config:
|
||||||
new_conf = ScheduleTimeWeight(
|
new_conf = ScheduleTimeWeight(
|
||||||
@ -243,18 +205,13 @@ class DBConfig:
|
|||||||
key = f"{target.platform_name}-{target.target}"
|
key = f"{target.platform_name}-{target.target}"
|
||||||
weight = target.default_schedule_weight
|
weight = target.default_schedule_weight
|
||||||
for time_conf in target.time_weight:
|
for time_conf in target.time_weight:
|
||||||
if (
|
if time_conf.start_time <= cur_time and time_conf.end_time > cur_time:
|
||||||
time_conf.start_time <= cur_time
|
|
||||||
and time_conf.end_time > cur_time
|
|
||||||
):
|
|
||||||
weight = time_conf.weight
|
weight = time_conf.weight
|
||||||
break
|
break
|
||||||
res[key] = weight
|
res[key] = weight
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def get_platform_target_subscribers(
|
async def get_platform_target_subscribers(self, platform_name: str, target: T_Target) -> list[UserSubInfo]:
|
||||||
self, platform_name: str, target: T_Target
|
|
||||||
) -> list[UserSubInfo]:
|
|
||||||
async with create_session() as sess:
|
async with create_session() as sess:
|
||||||
query = (
|
query = (
|
||||||
select(Subscribe)
|
select(Subscribe)
|
||||||
@ -263,16 +220,14 @@ class DBConfig:
|
|||||||
.options(selectinload(Subscribe.user))
|
.options(selectinload(Subscribe.user))
|
||||||
)
|
)
|
||||||
subsribes = (await sess.scalars(query)).all()
|
subsribes = (await sess.scalars(query)).all()
|
||||||
return list(
|
return [
|
||||||
map(
|
UserSubInfo(
|
||||||
lambda subscribe: UserSubInfo(
|
|
||||||
PlatformTarget.deserialize(subscribe.user.user_target),
|
PlatformTarget.deserialize(subscribe.user.user_target),
|
||||||
subscribe.categories,
|
subscribe.categories,
|
||||||
subscribe.tags,
|
subscribe.tags,
|
||||||
),
|
|
||||||
subsribes,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
for subscribe in subsribes
|
||||||
|
]
|
||||||
|
|
||||||
async def get_all_weight_config(
|
async def get_all_weight_config(
|
||||||
self,
|
self,
|
||||||
@ -281,9 +236,7 @@ class DBConfig:
|
|||||||
async with create_session() as sess:
|
async with create_session() as sess:
|
||||||
query = select(Target)
|
query = select(Target)
|
||||||
targets = (await sess.scalars(query)).all()
|
targets = (await sess.scalars(query)).all()
|
||||||
query = select(ScheduleTimeWeight).options(
|
query = select(ScheduleTimeWeight).options(selectinload(ScheduleTimeWeight.target))
|
||||||
selectinload(ScheduleTimeWeight.target)
|
|
||||||
)
|
|
||||||
time_weights = (await sess.scalars(query)).all()
|
time_weights = (await sess.scalars(query)).all()
|
||||||
|
|
||||||
for target in targets:
|
for target in targets:
|
||||||
@ -293,9 +246,7 @@ class DBConfig:
|
|||||||
target=T_Target(target.target),
|
target=T_Target(target.target),
|
||||||
target_name=target.target_name,
|
target_name=target.target_name,
|
||||||
platform_name=platform_name,
|
platform_name=platform_name,
|
||||||
weight=WeightConfig(
|
weight=WeightConfig(default=target.default_schedule_weight, time_config=[]),
|
||||||
default=target.default_schedule_weight, time_config=[]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for time_weight_config in time_weights:
|
for time_weight_config in time_weights:
|
||||||
|
@ -1,22 +1,17 @@
|
|||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot_plugin_datastore.db import get_engine
|
from nonebot_plugin_datastore.db import get_engine
|
||||||
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
|
|
||||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||||
|
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
|
||||||
|
|
||||||
|
from .db_model import User, Target, Subscribe
|
||||||
from .config_legacy import Config, ConfigContent, drop
|
from .config_legacy import Config, ConfigContent, drop
|
||||||
from .db_model import Subscribe, Target, User
|
|
||||||
|
|
||||||
|
|
||||||
async def data_migrate():
|
async def data_migrate():
|
||||||
config = Config()
|
config = Config()
|
||||||
if config.available:
|
if config.available:
|
||||||
logger.warning("You are still using legacy db, migrating to sqlite")
|
logger.warning("You are still using legacy db, migrating to sqlite")
|
||||||
all_subs: list[ConfigContent] = list(
|
all_subs: list[ConfigContent] = [ConfigContent(**item) for item in config.get_all_subscribe().all()]
|
||||||
map(
|
|
||||||
lambda item: ConfigContent(**item),
|
|
||||||
config.get_all_subscribe().all(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
async with AsyncSession(get_engine()) as sess:
|
async with AsyncSession(get_engine()) as sess:
|
||||||
user_to_create = []
|
user_to_create = []
|
||||||
subscribe_to_create = []
|
subscribe_to_create = []
|
||||||
@ -37,8 +32,7 @@ async def data_migrate():
|
|||||||
if key in user_sub_set:
|
if key in user_sub_set:
|
||||||
# a user subscribe a target twice
|
# a user subscribe a target twice
|
||||||
logger.error(
|
logger.error(
|
||||||
f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次,"
|
f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次,随机采用了一个订阅" # noqa: E501
|
||||||
"随机采用了一个订阅"
|
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
user_sub_set.add(key)
|
user_sub_set.add(key)
|
||||||
@ -69,11 +63,7 @@ async def data_migrate():
|
|||||||
tags=sub["tags"],
|
tags=sub["tags"],
|
||||||
)
|
)
|
||||||
subscribe_to_create.append(subscribe_obj)
|
subscribe_to_create.append(subscribe_obj)
|
||||||
sess.add_all(
|
sess.add_all(user_to_create + [x[0] for x in platform_target_map.values()] + subscribe_to_create)
|
||||||
user_to_create
|
|
||||||
+ list(map(lambda x: x[0], platform_target_map.values()))
|
|
||||||
+ subscribe_to_create
|
|
||||||
)
|
|
||||||
await sess.commit()
|
await sess.commit()
|
||||||
drop()
|
drop()
|
||||||
logger.info("migrate success")
|
logger.info("migrate success")
|
||||||
|
@ -5,7 +5,6 @@ Revises: 5f3370328e44
|
|||||||
Create Date: 2023-01-15 19:04:54.987491
|
Create Date: 2023-01-15 19:04:54.987491
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
|
@ -5,7 +5,6 @@ Revises: 0571870f5222
|
|||||||
Create Date: 2022-03-26 19:46:50.910721
|
Create Date: 2022-03-26 19:46:50.910721
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
@ -18,14 +17,10 @@ depends_on = None
|
|||||||
def upgrade():
|
def upgrade():
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
with op.batch_alter_table("subscribe", schema=None) as batch_op:
|
with op.batch_alter_table("subscribe", schema=None) as batch_op:
|
||||||
batch_op.create_unique_constraint(
|
batch_op.create_unique_constraint("unique-subscribe-constraint", ["target_id", "user_id"])
|
||||||
"unique-subscribe-constraint", ["target_id", "user_id"]
|
|
||||||
)
|
|
||||||
|
|
||||||
with op.batch_alter_table("target", schema=None) as batch_op:
|
with op.batch_alter_table("target", schema=None) as batch_op:
|
||||||
batch_op.create_unique_constraint(
|
batch_op.create_unique_constraint("unique-target-constraint", ["target", "platform_name"])
|
||||||
"unique-target-constraint", ["target", "platform_name"]
|
|
||||||
)
|
|
||||||
|
|
||||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||||
batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"])
|
batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"])
|
||||||
|
@ -1,15 +1,14 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
|
||||||
from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from nonebot_plugin_saa.utils import AllSupportedPlatformTarget as UserInfo
|
||||||
|
|
||||||
from ....types import Category, Tag
|
from ....types import Tag, Category
|
||||||
|
|
||||||
|
|
||||||
class NBESFBase(BaseModel, ABC):
|
class NBESFBase(BaseModel, ABC):
|
||||||
|
|
||||||
version: int # 表示nbesf格式版本,有效版本从1开始
|
version: int # 表示nbesf格式版本,有效版本从1开始
|
||||||
groups: list = list()
|
groups: list = []
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
@ -1,25 +1,26 @@
|
|||||||
|
from typing import cast
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Callable, cast
|
from collections.abc import Callable
|
||||||
|
|
||||||
from nonebot.log import logger
|
|
||||||
from nonebot_plugin_datastore.db import create_session
|
|
||||||
from nonebot_plugin_saa import PlatformTarget
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm.strategy_options import selectinload
|
from nonebot.log import logger
|
||||||
from sqlalchemy.sql.selectable import Select
|
from sqlalchemy.sql.selectable import Select
|
||||||
|
from nonebot_plugin_saa import PlatformTarget
|
||||||
|
from nonebot_plugin_datastore.db import create_session
|
||||||
|
from sqlalchemy.orm.strategy_options import selectinload
|
||||||
|
|
||||||
from ..db_model import Subscribe, User
|
|
||||||
from .nbesf_model import NBESFBase, v1, v2
|
|
||||||
from .utils import NBESFVerMatchErr
|
from .utils import NBESFVerMatchErr
|
||||||
|
from ..db_model import User, Subscribe
|
||||||
|
from .nbesf_model import NBESFBase, v1, v2
|
||||||
|
|
||||||
|
|
||||||
async def subscribes_export(selector: Callable[[Select], Select]) -> v2.SubGroup:
|
async def subscribes_export(selector: Callable[[Select], Select]) -> v2.SubGroup:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
将Bison订阅导出为 Nonebot Bison Exchangable Subscribes File 标准格式的 SubGroup 类型数据
|
将Bison订阅导出为 Nonebot Bison Exchangable Subscribes File 标准格式的 SubGroup 类型数据
|
||||||
|
|
||||||
selector:
|
selector:
|
||||||
对 sqlalchemy Select 对象的操作函数,用于限定查询范围 e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group")
|
对 sqlalchemy Select 对象的操作函数,用于限定查询范围
|
||||||
|
e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group")
|
||||||
"""
|
"""
|
||||||
async with create_session() as sess:
|
async with create_session() as sess:
|
||||||
sub_stmt = select(Subscribe).join(User)
|
sub_stmt = select(Subscribe).join(User)
|
||||||
|
@ -1,23 +1,22 @@
|
|||||||
from collections import defaultdict
|
|
||||||
from importlib import import_module
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pkgutil import iter_modules
|
from pkgutil import iter_modules
|
||||||
from typing import DefaultDict, Type
|
from collections import defaultdict
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
from .platform import Platform, make_no_target_group
|
from .platform import Platform, make_no_target_group
|
||||||
|
|
||||||
_package_dir = str(Path(__file__).resolve().parent)
|
_package_dir = str(Path(__file__).resolve().parent)
|
||||||
for (_, module_name, _) in iter_modules([_package_dir]):
|
for _, module_name, _ in iter_modules([_package_dir]):
|
||||||
import_module(f"{__name__}.{module_name}")
|
import_module(f"{__name__}.{module_name}")
|
||||||
|
|
||||||
|
|
||||||
_platform_list: DefaultDict[str, list[Type[Platform]]] = defaultdict(list)
|
_platform_list: defaultdict[str, list[type[Platform]]] = defaultdict(list)
|
||||||
for _platform in Platform.registry:
|
for _platform in Platform.registry:
|
||||||
if not _platform.enabled:
|
if not _platform.enabled:
|
||||||
continue
|
continue
|
||||||
_platform_list[_platform.platform_name].append(_platform)
|
_platform_list[_platform.platform_name].append(_platform)
|
||||||
|
|
||||||
platform_manager: dict[str, Type[Platform]] = dict()
|
platform_manager: dict[str, type[Platform]] = {}
|
||||||
for name, platform_list in _platform_list.items():
|
for name, platform_list in _platform_list.items():
|
||||||
if len(platform_list) == 1:
|
if len(platform_list) == 1:
|
||||||
platform_manager[name] = platform_list[0]
|
platform_manager[name] = platform_list[0]
|
||||||
|
@ -1,25 +1,23 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from bs4 import BeautifulSoup as bs
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from nonebot.plugin import require
|
from nonebot.plugin import require
|
||||||
|
from bs4 import BeautifulSoup as bs
|
||||||
|
|
||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import Category, RawPost, Target
|
from ..types import Target, RawPost, Category
|
||||||
from ..utils.scheduler_config import SchedulerConfig
|
from ..utils.scheduler_config import SchedulerConfig
|
||||||
from .platform import CategoryNotRecognize, NewMessage, StatusChange
|
from .platform import NewMessage, StatusChange, CategoryNotRecognize
|
||||||
|
|
||||||
|
|
||||||
class ArknightsSchedConf(SchedulerConfig):
|
class ArknightsSchedConf(SchedulerConfig):
|
||||||
|
|
||||||
name = "arknights"
|
name = "arknights"
|
||||||
schedule_type = "interval"
|
schedule_type = "interval"
|
||||||
schedule_setting = {"seconds": 30}
|
schedule_setting = {"seconds": 30}
|
||||||
|
|
||||||
|
|
||||||
class Arknights(NewMessage):
|
class Arknights(NewMessage):
|
||||||
|
|
||||||
categories = {1: "游戏公告"}
|
categories = {1: "游戏公告"}
|
||||||
platform_name = "arknights"
|
platform_name = "arknights"
|
||||||
name = "明日方舟游戏信息"
|
name = "明日方舟游戏信息"
|
||||||
@ -30,9 +28,7 @@ class Arknights(NewMessage):
|
|||||||
has_target = False
|
has_target = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_target_name(
|
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||||
cls, client: AsyncClient, target: Target
|
|
||||||
) -> Optional[str]:
|
|
||||||
return "明日方舟游戏信息"
|
return "明日方舟游戏信息"
|
||||||
|
|
||||||
async def get_sub_list(self, _) -> list[RawPost]:
|
async def get_sub_list(self, _) -> list[RawPost]:
|
||||||
@ -92,7 +88,6 @@ class Arknights(NewMessage):
|
|||||||
|
|
||||||
|
|
||||||
class AkVersion(StatusChange):
|
class AkVersion(StatusChange):
|
||||||
|
|
||||||
categories = {2: "更新信息"}
|
categories = {2: "更新信息"}
|
||||||
platform_name = "arknights"
|
platform_name = "arknights"
|
||||||
name = "明日方舟游戏信息"
|
name = "明日方舟游戏信息"
|
||||||
@ -103,15 +98,11 @@ class AkVersion(StatusChange):
|
|||||||
has_target = False
|
has_target = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_target_name(
|
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||||
cls, client: AsyncClient, target: Target
|
|
||||||
) -> Optional[str]:
|
|
||||||
return "明日方舟游戏信息"
|
return "明日方舟游戏信息"
|
||||||
|
|
||||||
async def get_status(self, _):
|
async def get_status(self, _):
|
||||||
res_ver = await self.client.get(
|
res_ver = await self.client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version")
|
||||||
"https://ak-conf.hypergryph.com/config/prod/official/IOS/version"
|
|
||||||
)
|
|
||||||
res_preanounce = await self.client.get(
|
res_preanounce = await self.client.get(
|
||||||
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json"
|
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json"
|
||||||
)
|
)
|
||||||
@ -121,20 +112,10 @@ class AkVersion(StatusChange):
|
|||||||
|
|
||||||
def compare_status(self, _, old_status, new_status):
|
def compare_status(self, _, old_status, new_status):
|
||||||
res = []
|
res = []
|
||||||
if (
|
if old_status.get("preAnnounceType") == 2 and new_status.get("preAnnounceType") == 0:
|
||||||
old_status.get("preAnnounceType") == 2
|
res.append(Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息")) # noqa: E501
|
||||||
and new_status.get("preAnnounceType") == 0
|
elif old_status.get("preAnnounceType") == 0 and new_status.get("preAnnounceType") == 2:
|
||||||
):
|
res.append(Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息")) # noqa: E501
|
||||||
res.append(
|
|
||||||
Post("arknights", text="登录界面维护公告上线(大概是开始维护了)", target_name="明日方舟更新信息")
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
old_status.get("preAnnounceType") == 0
|
|
||||||
and new_status.get("preAnnounceType") == 2
|
|
||||||
):
|
|
||||||
res.append(
|
|
||||||
Post("arknights", text="登录界面维护公告下线(大概是开服了,冲!)", target_name="明日方舟更新信息")
|
|
||||||
)
|
|
||||||
if old_status.get("clientVersion") != new_status.get("clientVersion"):
|
if old_status.get("clientVersion") != new_status.get("clientVersion"):
|
||||||
res.append(Post("arknights", text="游戏本体更新(大更新)", target_name="明日方舟更新信息"))
|
res.append(Post("arknights", text="游戏本体更新(大更新)", target_name="明日方舟更新信息"))
|
||||||
if old_status.get("resVersion") != new_status.get("resVersion"):
|
if old_status.get("resVersion") != new_status.get("resVersion"):
|
||||||
@ -149,7 +130,6 @@ class AkVersion(StatusChange):
|
|||||||
|
|
||||||
|
|
||||||
class MonsterSiren(NewMessage):
|
class MonsterSiren(NewMessage):
|
||||||
|
|
||||||
categories = {3: "塞壬唱片新闻"}
|
categories = {3: "塞壬唱片新闻"}
|
||||||
platform_name = "arknights"
|
platform_name = "arknights"
|
||||||
name = "明日方舟游戏信息"
|
name = "明日方舟游戏信息"
|
||||||
@ -160,15 +140,11 @@ class MonsterSiren(NewMessage):
|
|||||||
has_target = False
|
has_target = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_target_name(
|
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||||
cls, client: AsyncClient, target: Target
|
|
||||||
) -> Optional[str]:
|
|
||||||
return "明日方舟游戏信息"
|
return "明日方舟游戏信息"
|
||||||
|
|
||||||
async def get_sub_list(self, _) -> list[RawPost]:
|
async def get_sub_list(self, _) -> list[RawPost]:
|
||||||
raw_data = await self.client.get(
|
raw_data = await self.client.get("https://monster-siren.hypergryph.com/api/news")
|
||||||
"https://monster-siren.hypergryph.com/api/news"
|
|
||||||
)
|
|
||||||
return raw_data.json()["data"]["list"]
|
return raw_data.json()["data"]["list"]
|
||||||
|
|
||||||
def get_id(self, post: RawPost) -> Any:
|
def get_id(self, post: RawPost) -> Any:
|
||||||
@ -182,14 +158,12 @@ class MonsterSiren(NewMessage):
|
|||||||
|
|
||||||
async def parse(self, raw_post: RawPost) -> Post:
|
async def parse(self, raw_post: RawPost) -> Post:
|
||||||
url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}'
|
url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}'
|
||||||
res = await self.client.get(
|
res = await self.client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}')
|
||||||
f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}'
|
|
||||||
)
|
|
||||||
raw_data = res.json()
|
raw_data = res.json()
|
||||||
content = raw_data["data"]["content"]
|
content = raw_data["data"]["content"]
|
||||||
content = content.replace("</p>", "</p>\n")
|
content = content.replace("</p>", "</p>\n")
|
||||||
soup = bs(content, "html.parser")
|
soup = bs(content, "html.parser")
|
||||||
imgs = list(map(lambda x: x["src"], soup("img")))
|
imgs = [x["src"] for x in soup("img")]
|
||||||
text = f'{raw_post["title"]}\n{soup.text.strip()}'
|
text = f'{raw_post["title"]}\n{soup.text.strip()}'
|
||||||
return Post(
|
return Post(
|
||||||
"monster-siren",
|
"monster-siren",
|
||||||
@ -203,7 +177,6 @@ class MonsterSiren(NewMessage):
|
|||||||
|
|
||||||
|
|
||||||
class TerraHistoricusComic(NewMessage):
|
class TerraHistoricusComic(NewMessage):
|
||||||
|
|
||||||
categories = {4: "泰拉记事社漫画"}
|
categories = {4: "泰拉记事社漫画"}
|
||||||
platform_name = "arknights"
|
platform_name = "arknights"
|
||||||
name = "明日方舟游戏信息"
|
name = "明日方舟游戏信息"
|
||||||
@ -214,15 +187,11 @@ class TerraHistoricusComic(NewMessage):
|
|||||||
has_target = False
|
has_target = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_target_name(
|
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||||
cls, client: AsyncClient, target: Target
|
|
||||||
) -> Optional[str]:
|
|
||||||
return "明日方舟游戏信息"
|
return "明日方舟游戏信息"
|
||||||
|
|
||||||
async def get_sub_list(self, _) -> list[RawPost]:
|
async def get_sub_list(self, _) -> list[RawPost]:
|
||||||
raw_data = await self.client.get(
|
raw_data = await self.client.get("https://terra-historicus.hypergryph.com/api/recentUpdate")
|
||||||
"https://terra-historicus.hypergryph.com/api/recentUpdate"
|
|
||||||
)
|
|
||||||
return raw_data.json()["data"]
|
return raw_data.json()["data"]
|
||||||
|
|
||||||
def get_id(self, post: RawPost) -> Any:
|
def get_id(self, post: RawPost) -> Any:
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
import json
|
|
||||||
import re
|
import re
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from enum import Enum, unique
|
from enum import Enum, unique
|
||||||
from typing import Any, Literal, Optional
|
from typing_extensions import Self
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import Field, BaseModel
|
||||||
from typing_extensions import Self
|
|
||||||
|
|
||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import ApiError, Category, RawPost, Tag, Target
|
from ..types import ApiError, Category, RawPost, Tag, Target
|
||||||
@ -25,9 +25,7 @@ class BilibiliSchedConf(SchedulerConfig):
|
|||||||
cookie_expire_time = timedelta(hours=5)
|
cookie_expire_time = timedelta(hours=5)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._client_refresh_time = datetime(
|
self._client_refresh_time = datetime(year=2000, month=1, day=1) # an expired time
|
||||||
year=2000, month=1, day=1
|
|
||||||
) # an expired time
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
async def _init_session(self):
|
async def _init_session(self):
|
||||||
@ -69,12 +67,8 @@ class Bilibili(NewMessage):
|
|||||||
parse_target_promot = "请输入用户主页的链接"
|
parse_target_promot = "请输入用户主页的链接"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_target_name(
|
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||||
cls, client: AsyncClient, target: Target
|
res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target})
|
||||||
) -> Optional[str]:
|
|
||||||
res = await client.get(
|
|
||||||
"https://api.bilibili.com/x/web-interface/card", params={"mid": target}
|
|
||||||
)
|
|
||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
res_data = res.json()
|
res_data = res.json()
|
||||||
if res_data["code"]:
|
if res_data["code"]:
|
||||||
@ -129,12 +123,7 @@ class Bilibili(NewMessage):
|
|||||||
return self._do_get_category(post_type)
|
return self._do_get_category(post_type)
|
||||||
|
|
||||||
def get_tags(self, raw_post: RawPost) -> list[Tag]:
|
def get_tags(self, raw_post: RawPost) -> list[Tag]:
|
||||||
return [
|
return [*(tp["topic_name"] for tp in raw_post["display"]["topic_info"]["topic_details"])]
|
||||||
*map(
|
|
||||||
lambda tp: tp["topic_name"],
|
|
||||||
raw_post["display"]["topic_info"]["topic_details"],
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def _get_info(self, post_type: Category, card) -> tuple[str, list]:
|
def _get_info(self, post_type: Category, card) -> tuple[str, list]:
|
||||||
if post_type == 1:
|
if post_type == 1:
|
||||||
@ -178,24 +167,16 @@ class Bilibili(NewMessage):
|
|||||||
url = ""
|
url = ""
|
||||||
if post_type == 1:
|
if post_type == 1:
|
||||||
# 一般动态
|
# 一般动态
|
||||||
url = "https://t.bilibili.com/{}".format(
|
url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"])
|
||||||
raw_post["desc"]["dynamic_id_str"]
|
|
||||||
)
|
|
||||||
elif post_type == 2:
|
elif post_type == 2:
|
||||||
# 专栏文章
|
# 专栏文章
|
||||||
url = "https://www.bilibili.com/read/cv{}".format(
|
url = "https://www.bilibili.com/read/cv{}".format(raw_post["desc"]["rid"])
|
||||||
raw_post["desc"]["rid"]
|
|
||||||
)
|
|
||||||
elif post_type == 3:
|
elif post_type == 3:
|
||||||
# 视频
|
# 视频
|
||||||
url = "https://www.bilibili.com/video/{}".format(
|
url = "https://www.bilibili.com/video/{}".format(raw_post["desc"]["bvid"])
|
||||||
raw_post["desc"]["bvid"]
|
|
||||||
)
|
|
||||||
elif post_type == 4:
|
elif post_type == 4:
|
||||||
# 纯文字
|
# 纯文字
|
||||||
url = "https://t.bilibili.com/{}".format(
|
url = "https://t.bilibili.com/{}".format(raw_post["desc"]["dynamic_id_str"])
|
||||||
raw_post["desc"]["dynamic_id_str"]
|
|
||||||
)
|
|
||||||
text, pic = self._get_info(post_type, card_content)
|
text, pic = self._get_info(post_type, card_content)
|
||||||
elif post_type == 5:
|
elif post_type == 5:
|
||||||
# 转发
|
# 转发
|
||||||
@ -261,10 +242,7 @@ class Bilibililive(StatusChange):
|
|||||||
def get_live_action(self, old_info: Self) -> "Bilibililive.LiveAction":
|
def get_live_action(self, old_info: Self) -> "Bilibililive.LiveAction":
|
||||||
status = Bilibililive.LiveStatus
|
status = Bilibililive.LiveStatus
|
||||||
action = Bilibililive.LiveAction
|
action = Bilibililive.LiveAction
|
||||||
if (
|
if old_info.live_status in [status.OFF, status.CYCLE] and self.live_status == status.ON:
|
||||||
old_info.live_status in [status.OFF, status.CYCLE]
|
|
||||||
and self.live_status == status.ON
|
|
||||||
):
|
|
||||||
return action.TURN_ON
|
return action.TURN_ON
|
||||||
elif old_info.live_status == status.ON and self.live_status in [
|
elif old_info.live_status == status.ON and self.live_status in [
|
||||||
status.OFF,
|
status.OFF,
|
||||||
@ -281,12 +259,8 @@ class Bilibililive(StatusChange):
|
|||||||
return action.OFF
|
return action.OFF
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_target_name(
|
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||||
cls, client: AsyncClient, target: Target
|
res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target})
|
||||||
) -> Optional[str]:
|
|
||||||
res = await client.get(
|
|
||||||
"https://api.bilibili.com/x/web-interface/card", params={"mid": target}
|
|
||||||
)
|
|
||||||
res_data = json.loads(res.text)
|
res_data = json.loads(res.text)
|
||||||
if res_data["code"]:
|
if res_data["code"]:
|
||||||
return None
|
return None
|
||||||
@ -382,9 +356,7 @@ class BilibiliBangumi(StatusChange):
|
|||||||
_url = "https://api.bilibili.com/pgc/review/user"
|
_url = "https://api.bilibili.com/pgc/review/user"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_target_name(
|
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||||
cls, client: AsyncClient, target: Target
|
|
||||||
) -> Optional[str]:
|
|
||||||
res = await client.get(cls._url, params={"media_id": target})
|
res = await client.get(cls._url, params={"media_id": target})
|
||||||
res_data = res.json()
|
res_data = res.json()
|
||||||
if res_data["code"]:
|
if res_data["code"]:
|
||||||
|
@ -1,15 +1,14 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import RawPost, Target
|
|
||||||
from ..utils import scheduler
|
from ..utils import scheduler
|
||||||
from .platform import NewMessage
|
from .platform import NewMessage
|
||||||
|
from ..types import Target, RawPost
|
||||||
|
|
||||||
|
|
||||||
class FF14(NewMessage):
|
class FF14(NewMessage):
|
||||||
|
|
||||||
categories = {}
|
categories = {}
|
||||||
platform_name = "ff14"
|
platform_name = "ff14"
|
||||||
name = "最终幻想XIV官方公告"
|
name = "最终幻想XIV官方公告"
|
||||||
@ -21,9 +20,7 @@ class FF14(NewMessage):
|
|||||||
has_target = False
|
has_target = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_target_name(
|
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||||
cls, client: AsyncClient, target: Target
|
|
||||||
) -> Optional[str]:
|
|
||||||
return "最终幻想XIV官方公告"
|
return "最终幻想XIV官方公告"
|
||||||
|
|
||||||
async def get_sub_list(self, _) -> list[RawPost]:
|
async def get_sub_list(self, _) -> list[RawPost]:
|
||||||
|
@ -2,15 +2,15 @@ import re
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from bs4 import BeautifulSoup, Tag
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
|
from bs4 import Tag, BeautifulSoup
|
||||||
from nonebot.plugin import require
|
from nonebot.plugin import require
|
||||||
|
|
||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import Category, RawPost, Target
|
from ..types import Target, RawPost, Category
|
||||||
from ..utils import SchedulerConfig, http_client
|
from ..utils import SchedulerConfig, http_client
|
||||||
from .platform import CategoryNotRecognize, CategoryNotSupport, NewMessage
|
from .platform import NewMessage, CategoryNotSupport, CategoryNotRecognize
|
||||||
|
|
||||||
|
|
||||||
class McbbsnewsSchedConf(SchedulerConfig):
|
class McbbsnewsSchedConf(SchedulerConfig):
|
||||||
@ -134,9 +134,9 @@ class McbbsNews(NewMessage):
|
|||||||
if categoty_name in category_values:
|
if categoty_name in category_values:
|
||||||
category_id = category_keys[category_values.index(categoty_name)]
|
category_id = category_keys[category_values.index(categoty_name)]
|
||||||
elif categoty_name in known_category_values:
|
elif categoty_name in known_category_values:
|
||||||
raise CategoryNotSupport("McbbsNews订阅暂不支持 {}".format(categoty_name))
|
raise CategoryNotSupport(f"McbbsNews订阅暂不支持 {categoty_name}")
|
||||||
else:
|
else:
|
||||||
raise CategoryNotRecognize("Mcbbsnews订阅尚未识别 {}".format(categoty_name))
|
raise CategoryNotRecognize(f"Mcbbsnews订阅尚未识别 {categoty_name}")
|
||||||
return category_id
|
return category_id
|
||||||
|
|
||||||
async def parse(self, post: RawPost) -> Post:
|
async def parse(self, post: RawPost) -> Post:
|
||||||
@ -170,7 +170,7 @@ class McbbsNews(NewMessage):
|
|||||||
一般而言每条新闻的长度都很可观,图片生成时间比较喜人
|
一般而言每条新闻的长度都很可观,图片生成时间比较喜人
|
||||||
"""
|
"""
|
||||||
require("nonebot_plugin_htmlrender")
|
require("nonebot_plugin_htmlrender")
|
||||||
from nonebot_plugin_htmlrender import capture_element, text_to_pic
|
from nonebot_plugin_htmlrender import text_to_pic, capture_element
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert url
|
assert url
|
||||||
@ -181,7 +181,7 @@ class McbbsNews(NewMessage):
|
|||||||
device_scale_factor=3,
|
device_scale_factor=3,
|
||||||
)
|
)
|
||||||
assert pic_data
|
assert pic_data
|
||||||
except:
|
except Exception:
|
||||||
err_info = traceback.format_exc()
|
err_info = traceback.format_exc()
|
||||||
logger.warning(f"渲染错误:{err_info}")
|
logger.warning(f"渲染错误:{err_info}")
|
||||||
|
|
||||||
|
@ -1,23 +1,21 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import ApiError, RawPost, Target
|
|
||||||
from ..utils import SchedulerConfig
|
|
||||||
from .platform import NewMessage
|
from .platform import NewMessage
|
||||||
|
from ..utils import SchedulerConfig
|
||||||
|
from ..types import Target, RawPost, ApiError
|
||||||
|
|
||||||
|
|
||||||
class NcmSchedConf(SchedulerConfig):
|
class NcmSchedConf(SchedulerConfig):
|
||||||
|
|
||||||
name = "music.163.com"
|
name = "music.163.com"
|
||||||
schedule_type = "interval"
|
schedule_type = "interval"
|
||||||
schedule_setting = {"minutes": 1}
|
schedule_setting = {"minutes": 1}
|
||||||
|
|
||||||
|
|
||||||
class NcmArtist(NewMessage):
|
class NcmArtist(NewMessage):
|
||||||
|
|
||||||
categories = {}
|
categories = {}
|
||||||
platform_name = "ncm-artist"
|
platform_name = "ncm-artist"
|
||||||
enable_tag = False
|
enable_tag = False
|
||||||
@ -29,11 +27,9 @@ class NcmArtist(NewMessage):
|
|||||||
parse_target_promot = "请输入歌手主页(包含数字ID)的链接"
|
parse_target_promot = "请输入歌手主页(包含数字ID)的链接"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_target_name(
|
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||||
cls, client: AsyncClient, target: Target
|
|
||||||
) -> Optional[str]:
|
|
||||||
res = await client.get(
|
res = await client.get(
|
||||||
"https://music.163.com/api/artist/albums/{}".format(target),
|
f"https://music.163.com/api/artist/albums/{target}",
|
||||||
headers={"Referer": "https://music.163.com/"},
|
headers={"Referer": "https://music.163.com/"},
|
||||||
)
|
)
|
||||||
res_data = res.json()
|
res_data = res.json()
|
||||||
@ -45,16 +41,14 @@ class NcmArtist(NewMessage):
|
|||||||
async def parse_target(cls, target_text: str) -> Target:
|
async def parse_target(cls, target_text: str) -> Target:
|
||||||
if re.match(r"^\d+$", target_text):
|
if re.match(r"^\d+$", target_text):
|
||||||
return Target(target_text)
|
return Target(target_text)
|
||||||
elif match := re.match(
|
elif match := re.match(r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text):
|
||||||
r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text
|
|
||||||
):
|
|
||||||
return Target(match.group(1))
|
return Target(match.group(1))
|
||||||
else:
|
else:
|
||||||
raise cls.ParseTargetException()
|
raise cls.ParseTargetException()
|
||||||
|
|
||||||
async def get_sub_list(self, target: Target) -> list[RawPost]:
|
async def get_sub_list(self, target: Target) -> list[RawPost]:
|
||||||
res = await self.client.get(
|
res = await self.client.get(
|
||||||
"https://music.163.com/api/artist/albums/{}".format(target),
|
f"https://music.163.com/api/artist/albums/{target}",
|
||||||
headers={"Referer": "https://music.163.com/"},
|
headers={"Referer": "https://music.163.com/"},
|
||||||
)
|
)
|
||||||
res_data = res.json()
|
res_data = res.json()
|
||||||
@ -74,13 +68,10 @@ class NcmArtist(NewMessage):
|
|||||||
target_name = raw_post["artist"]["name"]
|
target_name = raw_post["artist"]["name"]
|
||||||
pics = [raw_post["picUrl"]]
|
pics = [raw_post["picUrl"]]
|
||||||
url = "https://music.163.com/#/album?id={}".format(raw_post["id"])
|
url = "https://music.163.com/#/album?id={}".format(raw_post["id"])
|
||||||
return Post(
|
return Post("ncm-artist", text=text, url=url, pics=pics, target_name=target_name)
|
||||||
"ncm-artist", text=text, url=url, pics=pics, target_name=target_name
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class NcmRadio(NewMessage):
|
class NcmRadio(NewMessage):
|
||||||
|
|
||||||
categories = {}
|
categories = {}
|
||||||
platform_name = "ncm-radio"
|
platform_name = "ncm-radio"
|
||||||
enable_tag = False
|
enable_tag = False
|
||||||
@ -92,9 +83,7 @@ class NcmRadio(NewMessage):
|
|||||||
parse_target_promot = "请输入主播电台主页(包含数字ID)的链接"
|
parse_target_promot = "请输入主播电台主页(包含数字ID)的链接"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_target_name(
|
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||||
cls, client: AsyncClient, target: Target
|
|
||||||
) -> Optional[str]:
|
|
||||||
res = await client.post(
|
res = await client.post(
|
||||||
"http://music.163.com/api/dj/program/byradio",
|
"http://music.163.com/api/dj/program/byradio",
|
||||||
headers={"Referer": "https://music.163.com/"},
|
headers={"Referer": "https://music.163.com/"},
|
||||||
@ -109,9 +98,7 @@ class NcmRadio(NewMessage):
|
|||||||
async def parse_target(cls, target_text: str) -> Target:
|
async def parse_target(cls, target_text: str) -> Target:
|
||||||
if re.match(r"^\d+$", target_text):
|
if re.match(r"^\d+$", target_text):
|
||||||
return Target(target_text)
|
return Target(target_text)
|
||||||
elif match := re.match(
|
elif match := re.match(r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text):
|
||||||
r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text
|
|
||||||
):
|
|
||||||
return Target(match.group(1))
|
return Target(match.group(1))
|
||||||
else:
|
else:
|
||||||
raise cls.ParseTargetException()
|
raise cls.ParseTargetException()
|
||||||
|
@ -1,29 +1,32 @@
|
|||||||
import json
|
|
||||||
import ssl
|
import ssl
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
|
from typing import Any
|
||||||
|
from dataclasses import dataclass
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from collections.abc import Collection
|
||||||
from typing import Any, Collection, Optional, Type
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot_plugin_saa import PlatformTarget
|
from nonebot_plugin_saa import PlatformTarget
|
||||||
|
|
||||||
from ..plugin_config import plugin_config
|
|
||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import Category, RawPost, Tag, Target, UserSubInfo
|
from ..plugin_config import plugin_config
|
||||||
from ..utils import ProcessContext, SchedulerConfig
|
from ..utils import ProcessContext, SchedulerConfig
|
||||||
|
from ..types import Tag, Target, RawPost, Category, UserSubInfo
|
||||||
|
|
||||||
|
|
||||||
class CategoryNotSupport(Exception):
|
class CategoryNotSupport(Exception):
|
||||||
"raise in get_category, when you know the category of the post but don't want to support it or don't support its parsing yet"
|
"""raise in get_category, when you know the category of the post
|
||||||
|
but don't want to support it or don't support its parsing yet
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class CategoryNotRecognize(Exception):
|
class CategoryNotRecognize(Exception):
|
||||||
"raise in get_category, when you don't know the category of post"
|
"""raise in get_category, when you don't know the category of post"""
|
||||||
|
|
||||||
|
|
||||||
class RegistryMeta(type):
|
class RegistryMeta(type):
|
||||||
@ -42,7 +45,6 @@ class RegistryMeta(type):
|
|||||||
|
|
||||||
|
|
||||||
class PlatformMeta(RegistryMeta):
|
class PlatformMeta(RegistryMeta):
|
||||||
|
|
||||||
categories: dict[Category, str]
|
categories: dict[Category, str]
|
||||||
store: dict[Target, Any]
|
store: dict[Target, Any]
|
||||||
|
|
||||||
@ -60,8 +62,7 @@ class PlatformABCMeta(PlatformMeta, ABC):
|
|||||||
|
|
||||||
|
|
||||||
class Platform(metaclass=PlatformABCMeta, base=True):
|
class Platform(metaclass=PlatformABCMeta, base=True):
|
||||||
|
scheduler: type[SchedulerConfig]
|
||||||
scheduler: Type[SchedulerConfig]
|
|
||||||
ctx: ProcessContext
|
ctx: ProcessContext
|
||||||
is_common: bool
|
is_common: bool
|
||||||
enabled: bool
|
enabled: bool
|
||||||
@ -70,16 +71,14 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
|||||||
categories: dict[Category, str]
|
categories: dict[Category, str]
|
||||||
enable_tag: bool
|
enable_tag: bool
|
||||||
platform_name: str
|
platform_name: str
|
||||||
parse_target_promot: Optional[str] = None
|
parse_target_promot: str | None = None
|
||||||
registry: list[Type["Platform"]]
|
registry: list[type["Platform"]]
|
||||||
client: AsyncClient
|
client: AsyncClient
|
||||||
reverse_category: dict[str, Category]
|
reverse_category: dict[str, Category]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_target_name(
|
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||||
cls, client: AsyncClient, target: Target
|
|
||||||
) -> Optional[str]:
|
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -95,11 +94,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
|||||||
return await self.fetch_new_post(target, users)
|
return await self.fetch_new_post(target, users)
|
||||||
except httpx.RequestError as err:
|
except httpx.RequestError as err:
|
||||||
if plugin_config.bison_show_network_warning:
|
if plugin_config.bison_show_network_warning:
|
||||||
logger.warning(
|
logger.warning(f"network connection error: {type(err)}, url: {err.request.url}")
|
||||||
"network connection error: {}, url: {}".format(
|
|
||||||
type(err), err.request.url
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return []
|
return []
|
||||||
except ssl.SSLError as err:
|
except ssl.SSLError as err:
|
||||||
if plugin_config.bison_show_network_warning:
|
if plugin_config.bison_show_network_warning:
|
||||||
@ -130,7 +125,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
|||||||
return Target(target_string)
|
return Target(target_string)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]:
|
def get_tags(self, raw_post: RawPost) -> Collection[Tag] | None:
|
||||||
"Return Tag list of given RawPost"
|
"Return Tag list of given RawPost"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -201,9 +196,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
|||||||
) -> list[tuple[PlatformTarget, list[Post]]]:
|
) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||||
res: list[tuple[PlatformTarget, list[Post]]] = []
|
res: list[tuple[PlatformTarget, list[Post]]] = []
|
||||||
for user, cats, required_tags in users:
|
for user, cats, required_tags in users:
|
||||||
user_raw_post = await self.filter_user_custom(
|
user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
|
||||||
new_posts, cats, required_tags
|
|
||||||
)
|
|
||||||
user_post: list[Post] = []
|
user_post: list[Post] = []
|
||||||
for raw_post in user_raw_post:
|
for raw_post in user_raw_post:
|
||||||
user_post.append(await self.do_parse(raw_post))
|
user_post.append(await self.do_parse(raw_post))
|
||||||
@ -211,7 +204,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_category(self, post: RawPost) -> Optional[Category]:
|
def get_category(self, post: RawPost) -> Category | None:
|
||||||
"Return category of given Rawpost"
|
"Return category of given Rawpost"
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -221,7 +214,7 @@ class MessageProcess(Platform, abstract=True):
|
|||||||
|
|
||||||
def __init__(self, ctx: ProcessContext, client: AsyncClient):
|
def __init__(self, ctx: ProcessContext, client: AsyncClient):
|
||||||
super().__init__(ctx, client)
|
super().__init__(ctx, client)
|
||||||
self.parse_cache: dict[Any, Post] = dict()
|
self.parse_cache: dict[Any, Post] = {}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_id(self, post: RawPost) -> Any:
|
def get_id(self, post: RawPost) -> Any:
|
||||||
@ -246,7 +239,7 @@ class MessageProcess(Platform, abstract=True):
|
|||||||
"Get post list of the given target"
|
"Get post list of the given target"
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_date(self, post: RawPost) -> Optional[int]:
|
def get_date(self, post: RawPost) -> int | None:
|
||||||
"Get post timestamp and return, return None if can't get the time"
|
"Get post timestamp and return, return None if can't get the time"
|
||||||
|
|
||||||
async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
|
async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
|
||||||
@ -286,9 +279,7 @@ class NewMessage(MessageProcess, abstract=True):
|
|||||||
inited: bool
|
inited: bool
|
||||||
exists_posts: set[Any]
|
exists_posts: set[Any]
|
||||||
|
|
||||||
async def filter_common_with_diff(
|
async def filter_common_with_diff(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]:
|
||||||
self, target: Target, raw_post_list: list[RawPost]
|
|
||||||
) -> list[RawPost]:
|
|
||||||
filtered_post = await self.filter_common(raw_post_list)
|
filtered_post = await self.filter_common(raw_post_list)
|
||||||
store = self.get_stored_data(target) or self.MessageStorage(False, set())
|
store = self.get_stored_data(target) or self.MessageStorage(False, set())
|
||||||
res = []
|
res = []
|
||||||
@ -297,11 +288,7 @@ class NewMessage(MessageProcess, abstract=True):
|
|||||||
for raw_post in filtered_post:
|
for raw_post in filtered_post:
|
||||||
post_id = self.get_id(raw_post)
|
post_id = self.get_id(raw_post)
|
||||||
store.exists_posts.add(post_id)
|
store.exists_posts.add(post_id)
|
||||||
logger.info(
|
logger.info(f"init {self.platform_name}-{target} with {store.exists_posts}")
|
||||||
"init {}-{} with {}".format(
|
|
||||||
self.platform_name, target, store.exists_posts
|
|
||||||
)
|
|
||||||
)
|
|
||||||
store.inited = True
|
store.inited = True
|
||||||
else:
|
else:
|
||||||
for raw_post in filtered_post:
|
for raw_post in filtered_post:
|
||||||
@ -400,12 +387,11 @@ class SimplePost(MessageProcess, abstract=True):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
|
def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
|
|
||||||
class NoTargetGroup(Platform, abstract=True):
|
class NoTargetGroup(Platform, abstract=True):
|
||||||
platform_list: list[Type[Platform]]
|
platform_list: list[type[Platform]]
|
||||||
platform_obj_list: list[Platform]
|
platform_obj_list: list[Platform]
|
||||||
|
|
||||||
DUMMY_STR = "_DUMMY"
|
DUMMY_STR = "_DUMMY"
|
||||||
@ -418,24 +404,18 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
|
|||||||
|
|
||||||
for platform in platform_list:
|
for platform in platform_list:
|
||||||
if platform.has_target:
|
if platform.has_target:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"Platform {platform.name} should have no target")
|
||||||
"Platform {} should have no target".format(platform.name)
|
|
||||||
)
|
|
||||||
if name == DUMMY_STR:
|
if name == DUMMY_STR:
|
||||||
name = platform.name
|
name = platform.name
|
||||||
elif name != platform.name:
|
elif name != platform.name:
|
||||||
raise RuntimeError("Platform name for {} not fit".format(platform_name))
|
raise RuntimeError(f"Platform name for {platform_name} not fit")
|
||||||
platform_category_key_set = set(platform.categories.keys())
|
platform_category_key_set = set(platform.categories.keys())
|
||||||
if platform_category_key_set & categories_keys:
|
if platform_category_key_set & categories_keys:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"Platform categories for {platform_name} duplicate")
|
||||||
"Platform categories for {} duplicate".format(platform_name)
|
|
||||||
)
|
|
||||||
categories_keys |= platform_category_key_set
|
categories_keys |= platform_category_key_set
|
||||||
categories.update(platform.categories)
|
categories.update(platform.categories)
|
||||||
if platform.scheduler != scheduler:
|
if platform.scheduler != scheduler:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"Platform scheduler for {platform_name} not fit")
|
||||||
"Platform scheduler for {} not fit".format(platform_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient):
|
def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient):
|
||||||
Platform.__init__(self, ctx, client)
|
Platform.__init__(self, ctx, client)
|
||||||
@ -444,15 +424,13 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
|
|||||||
self.platform_obj_list.append(platform_class(ctx, client))
|
self.platform_obj_list.append(platform_class(ctx, client))
|
||||||
|
|
||||||
def __str__(self: "NoTargetGroup") -> str:
|
def __str__(self: "NoTargetGroup") -> str:
|
||||||
return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]"
|
return "[" + " ".join(x.name for x in self.platform_list) + "]"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_target_name(cls, client: AsyncClient, target: Target):
|
async def get_target_name(cls, client: AsyncClient, target: Target):
|
||||||
return await platform_list[0].get_target_name(client, target)
|
return await platform_list[0].get_target_name(client, target)
|
||||||
|
|
||||||
async def fetch_new_post(
|
async def fetch_new_post(self: "NoTargetGroup", target: Target, users: list[UserSubInfo]):
|
||||||
self: "NoTargetGroup", target: Target, users: list[UserSubInfo]
|
|
||||||
):
|
|
||||||
res = defaultdict(list)
|
res = defaultdict(list)
|
||||||
for platform in self.platform_obj_list:
|
for platform in self.platform_obj_list:
|
||||||
platform_res = await platform.fetch_new_post(target=target, users=users)
|
platform_res = await platform.fetch_new_post(target=target, users=users)
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import calendar
|
import calendar
|
||||||
import time
|
import time
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
import feedparser
|
import feedparser
|
||||||
from bs4 import BeautifulSoup as bs
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
from bs4 import BeautifulSoup as bs
|
||||||
|
|
||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import RawPost, Target
|
from ..types import RawPost, Target
|
||||||
@ -20,7 +20,6 @@ class RssSchedConf(SchedulerConfig):
|
|||||||
|
|
||||||
|
|
||||||
class Rss(NewMessage):
|
class Rss(NewMessage):
|
||||||
|
|
||||||
categories = {}
|
categories = {}
|
||||||
enable_tag = False
|
enable_tag = False
|
||||||
platform_name = "rss"
|
platform_name = "rss"
|
||||||
@ -31,9 +30,7 @@ class Rss(NewMessage):
|
|||||||
has_target = True
|
has_target = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_target_name(
|
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||||
cls, client: AsyncClient, target: Target
|
|
||||||
) -> Optional[str]:
|
|
||||||
res = await client.get(target, timeout=10.0)
|
res = await client.get(target, timeout=10.0)
|
||||||
feed = feedparser.parse(res.text)
|
feed = feedparser.parse(res.text)
|
||||||
return feed["feed"]["title"]
|
return feed["feed"]["title"]
|
||||||
@ -69,7 +66,7 @@ class Rss(NewMessage):
|
|||||||
else:
|
else:
|
||||||
text = f"{title}\n\n{desc}"
|
text = f"{title}\n\n{desc}"
|
||||||
|
|
||||||
pics = list(map(lambda x: x.attrs["src"], soup("img")))
|
pics = [x.attrs["src"] for x in soup("img")]
|
||||||
if raw_post.get("media_content"):
|
if raw_post.get("media_content"):
|
||||||
for media in raw_post["media_content"]:
|
for media in raw_post["media_content"]:
|
||||||
if media.get("medium") == "image" and media.get("url"):
|
if media.get("medium") == "image" and media.get("url"):
|
||||||
|
@ -1,17 +1,16 @@
|
|||||||
import json
|
|
||||||
import re
|
import re
|
||||||
from collections.abc import Callable
|
import json
|
||||||
|
from typing import Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from bs4 import BeautifulSoup as bs
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
|
from bs4 import BeautifulSoup as bs
|
||||||
|
|
||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import *
|
|
||||||
from ..utils import SchedulerConfig, http_client
|
|
||||||
from .platform import NewMessage
|
from .platform import NewMessage
|
||||||
|
from ..utils import SchedulerConfig, http_client
|
||||||
|
from ..types import Tag, Target, RawPost, ApiError, Category
|
||||||
|
|
||||||
|
|
||||||
class WeiboSchedConf(SchedulerConfig):
|
class WeiboSchedConf(SchedulerConfig):
|
||||||
@ -21,7 +20,6 @@ class WeiboSchedConf(SchedulerConfig):
|
|||||||
|
|
||||||
|
|
||||||
class Weibo(NewMessage):
|
class Weibo(NewMessage):
|
||||||
|
|
||||||
categories = {
|
categories = {
|
||||||
1: "转发",
|
1: "转发",
|
||||||
2: "视频",
|
2: "视频",
|
||||||
@ -38,13 +36,9 @@ class Weibo(NewMessage):
|
|||||||
parse_target_promot = "请输入用户主页(包含数字UID)的链接"
|
parse_target_promot = "请输入用户主页(包含数字UID)的链接"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_target_name(
|
async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
|
||||||
cls, client: AsyncClient, target: Target
|
|
||||||
) -> Optional[str]:
|
|
||||||
param = {"containerid": "100505" + target}
|
param = {"containerid": "100505" + target}
|
||||||
res = await client.get(
|
res = await client.get("https://m.weibo.cn/api/container/getIndex", params=param)
|
||||||
"https://m.weibo.cn/api/container/getIndex", params=param
|
|
||||||
)
|
|
||||||
res_dict = json.loads(res.text)
|
res_dict = json.loads(res.text)
|
||||||
if res_dict.get("ok") == 1:
|
if res_dict.get("ok") == 1:
|
||||||
return res_dict["data"]["userInfo"]["screen_name"]
|
return res_dict["data"]["userInfo"]["screen_name"]
|
||||||
@ -63,13 +57,14 @@ class Weibo(NewMessage):
|
|||||||
|
|
||||||
async def get_sub_list(self, target: Target) -> list[RawPost]:
|
async def get_sub_list(self, target: Target) -> list[RawPost]:
|
||||||
params = {"containerid": "107603" + target}
|
params = {"containerid": "107603" + target}
|
||||||
res = await self.client.get(
|
res = await self.client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0)
|
||||||
"https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0
|
|
||||||
)
|
|
||||||
res_data = json.loads(res.text)
|
res_data = json.loads(res.text)
|
||||||
if not res_data["ok"] and res_data["msg"] != "这里还没有内容":
|
if not res_data["ok"] and res_data["msg"] != "这里还没有内容":
|
||||||
raise ApiError(res.request.url)
|
raise ApiError(res.request.url)
|
||||||
custom_filter: Callable[[RawPost], bool] = lambda d: d["card_type"] == 9
|
|
||||||
|
def custom_filter(d: RawPost) -> bool:
|
||||||
|
return d["card_type"] == 9
|
||||||
|
|
||||||
return list(filter(custom_filter, res_data["data"]["cards"]))
|
return list(filter(custom_filter, res_data["data"]["cards"]))
|
||||||
|
|
||||||
def get_id(self, post: RawPost) -> Any:
|
def get_id(self, post: RawPost) -> Any:
|
||||||
@ -79,44 +74,32 @@ class Weibo(NewMessage):
|
|||||||
return raw_post["card_type"] == 9
|
return raw_post["card_type"] == 9
|
||||||
|
|
||||||
def get_date(self, raw_post: RawPost) -> float:
|
def get_date(self, raw_post: RawPost) -> float:
|
||||||
created_time = datetime.strptime(
|
created_time = datetime.strptime(raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y")
|
||||||
raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y"
|
|
||||||
)
|
|
||||||
return created_time.timestamp()
|
return created_time.timestamp()
|
||||||
|
|
||||||
def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]:
|
def get_tags(self, raw_post: RawPost) -> list[Tag] | None:
|
||||||
"Return Tag list of given RawPost"
|
"Return Tag list of given RawPost"
|
||||||
text = raw_post["mblog"]["text"]
|
text = raw_post["mblog"]["text"]
|
||||||
soup = bs(text, "html.parser")
|
soup = bs(text, "html.parser")
|
||||||
res = list(
|
res = [
|
||||||
map(
|
x[1:-1]
|
||||||
lambda x: x[1:-1],
|
for x in filter(
|
||||||
filter(
|
|
||||||
lambda s: s[0] == "#" and s[-1] == "#",
|
lambda s: s[0] == "#" and s[-1] == "#",
|
||||||
map(lambda x: x.text, soup.find_all("span", class_="surl-text")),
|
(x.text for x in soup.find_all("span", class_="surl-text")),
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
super_topic_img = soup.find(
|
|
||||||
"img", src=re.compile(r"timeline_card_small_super_default")
|
|
||||||
)
|
)
|
||||||
|
]
|
||||||
|
super_topic_img = soup.find("img", src=re.compile(r"timeline_card_small_super_default"))
|
||||||
if super_topic_img:
|
if super_topic_img:
|
||||||
try:
|
try:
|
||||||
res.append(
|
res.append(super_topic_img.parent.parent.find("span", class_="surl-text").text + "超话") # type: ignore
|
||||||
super_topic_img.parent.parent.find("span", class_="surl-text").text # type: ignore
|
except Exception:
|
||||||
+ "超话"
|
logger.info(f"super_topic extract error: {text}")
|
||||||
)
|
|
||||||
except:
|
|
||||||
logger.info("super_topic extract error: {}".format(text))
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def get_category(self, raw_post: RawPost) -> Category:
|
def get_category(self, raw_post: RawPost) -> Category:
|
||||||
if raw_post["mblog"].get("retweeted_status"):
|
if raw_post["mblog"].get("retweeted_status"):
|
||||||
return Category(1)
|
return Category(1)
|
||||||
elif (
|
elif raw_post["mblog"].get("page_info") and raw_post["mblog"]["page_info"].get("type") == "video":
|
||||||
raw_post["mblog"].get("page_info")
|
|
||||||
and raw_post["mblog"]["page_info"].get("type") == "video"
|
|
||||||
):
|
|
||||||
return Category(2)
|
return Category(2)
|
||||||
elif raw_post["mblog"].get("pics"):
|
elif raw_post["mblog"].get("pics"):
|
||||||
return Category(3)
|
return Category(3)
|
||||||
@ -129,7 +112,8 @@ class Weibo(NewMessage):
|
|||||||
|
|
||||||
async def parse(self, raw_post: RawPost) -> Post:
|
async def parse(self, raw_post: RawPost) -> Post:
|
||||||
header = {
|
header = {
|
||||||
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
|
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,"
|
||||||
|
"*/*;q=0.8,application/signed-exchange;v=b3;q=0.9",
|
||||||
"accept-language": "zh-CN,zh;q=0.9",
|
"accept-language": "zh-CN,zh;q=0.9",
|
||||||
"authority": "m.weibo.cn",
|
"authority": "m.weibo.cn",
|
||||||
"cache-control": "max-age=0",
|
"cache-control": "max-age=0",
|
||||||
@ -147,26 +131,16 @@ class Weibo(NewMessage):
|
|||||||
retweeted = True
|
retweeted = True
|
||||||
pic_num = info["retweeted_status"]["pic_num"] if retweeted else info["pic_num"]
|
pic_num = info["retweeted_status"]["pic_num"] if retweeted else info["pic_num"]
|
||||||
if info["isLongText"] or pic_num > 9:
|
if info["isLongText"] or pic_num > 9:
|
||||||
res = await self.client.get(
|
res = await self.client.get(f"https://m.weibo.cn/detail/{info['mid']}", headers=header)
|
||||||
"https://m.weibo.cn/detail/{}".format(info["mid"]), headers=header
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
match = re.search(r'"status": ([\s\S]+),\s+"call"', res.text)
|
match = re.search(r'"status": ([\s\S]+),\s+"call"', res.text)
|
||||||
assert match
|
assert match
|
||||||
full_json_text = match.group(1)
|
full_json_text = match.group(1)
|
||||||
info = json.loads(full_json_text)
|
info = json.loads(full_json_text)
|
||||||
except:
|
except Exception:
|
||||||
logger.info(
|
logger.info(f"detail message error: https://m.weibo.cn/detail/{info['mid']}")
|
||||||
"detail message error: https://m.weibo.cn/detail/{}".format(
|
|
||||||
info["mid"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
parsed_text = self._get_text(info["text"])
|
parsed_text = self._get_text(info["text"])
|
||||||
raw_pics_list = (
|
raw_pics_list = info["retweeted_status"].get("pics", []) if retweeted else info.get("pics", [])
|
||||||
info["retweeted_status"].get("pics", [])
|
|
||||||
if retweeted
|
|
||||||
else info.get("pics", [])
|
|
||||||
)
|
|
||||||
pic_urls = [img["large"]["url"] for img in raw_pics_list]
|
pic_urls = [img["large"]["url"] for img in raw_pics_list]
|
||||||
pics = []
|
pics = []
|
||||||
for pic_url in pic_urls:
|
for pic_url in pic_urls:
|
||||||
@ -174,7 +148,7 @@ class Weibo(NewMessage):
|
|||||||
res = await client.get(pic_url)
|
res = await client.get(pic_url)
|
||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
pics.append(res.content)
|
pics.append(res.content)
|
||||||
detail_url = "https://weibo.com/{}/{}".format(info["user"]["id"], info["bid"])
|
detail_url = f"https://weibo.com/{info['user']['id']}/{info['bid']}"
|
||||||
# return parsed_text, detail_url, pic_urls
|
# return parsed_text, detail_url, pic_urls
|
||||||
return Post(
|
return Post(
|
||||||
"weibo",
|
"weibo",
|
||||||
|
@ -1,11 +1,8 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import nonebot
|
import nonebot
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
class PlugConfig(BaseSettings):
|
class PlugConfig(BaseSettings):
|
||||||
|
|
||||||
bison_config_path: str = ""
|
bison_config_path: str = ""
|
||||||
bison_use_pic: bool = False
|
bison_use_pic: bool = False
|
||||||
bison_init_filter: bool = True
|
bison_init_filter: bool = True
|
||||||
@ -17,8 +14,12 @@ class PlugConfig(BaseSettings):
|
|||||||
bison_use_pic_merge: int = 0 # 多图片时启用图片合并转发(仅限群)
|
bison_use_pic_merge: int = 0 # 多图片时启用图片合并转发(仅限群)
|
||||||
# 0:不启用;1:首条消息单独发送,剩余照片合并转发;2以及以上:所有消息全部合并转发
|
# 0:不启用;1:首条消息单独发送,剩余照片合并转发;2以及以上:所有消息全部合并转发
|
||||||
bison_resend_times: int = 0
|
bison_resend_times: int = 0
|
||||||
bison_proxy: Optional[str]
|
bison_proxy: str | None
|
||||||
bison_ua: str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36"
|
bison_ua: str = (
|
||||||
|
"Mozilla/5.0 (X11; Linux x86_64) "
|
||||||
|
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||||
|
"Chrome/51.0.2704.103 Safari/537.36"
|
||||||
|
)
|
||||||
bison_show_network_warning: bool = True
|
bison_show_network_warning: bool = True
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
from .post import Post
|
from .post import Post
|
||||||
|
|
||||||
__all__ = ["Post", "CustomPost"]
|
__all__ = ["Post"]
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from abc import abstractmethod
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import Optional
|
from abc import abstractmethod
|
||||||
|
from dataclasses import field, dataclass
|
||||||
|
|
||||||
from nonebot_plugin_saa import MessageFactory, MessageSegmentFactory
|
from nonebot_plugin_saa import MessageFactory, MessageSegmentFactory
|
||||||
|
|
||||||
@ -25,12 +24,12 @@ class BasePost:
|
|||||||
class OptionalMixin:
|
class OptionalMixin:
|
||||||
# Because of https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
|
# Because of https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
|
||||||
|
|
||||||
override_use_pic: Optional[bool] = None
|
override_use_pic: bool | None = None
|
||||||
compress: bool = False
|
compress: bool = False
|
||||||
extra_msg: list[MessageFactory] = field(default_factory=list)
|
extra_msg: list[MessageFactory] = field(default_factory=list)
|
||||||
|
|
||||||
def _use_pic(self):
|
def _use_pic(self):
|
||||||
if not self.override_use_pic is None:
|
if self.override_use_pic is not None:
|
||||||
return self.override_use_pic
|
return self.override_use_pic
|
||||||
return plugin_config.bison_use_pic
|
return plugin_config.bison_use_pic
|
||||||
|
|
||||||
@ -44,13 +43,9 @@ class AbstractPost(OptionalMixin, BasePost):
|
|||||||
msg_segments = await self.generate_text_messages()
|
msg_segments = await self.generate_text_messages()
|
||||||
if msg_segments:
|
if msg_segments:
|
||||||
if self.compress:
|
if self.compress:
|
||||||
msgs = [
|
msgs = [reduce(lambda x, y: x.append(y), msg_segments, MessageFactory([]))]
|
||||||
reduce(lambda x, y: x.append(y), msg_segments, MessageFactory([]))
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
msgs = list(
|
msgs = [MessageFactory([msg_segment]) for msg_segment in msg_segments]
|
||||||
map(lambda msg_segment: MessageFactory([msg_segment]), msg_segments)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
msgs = []
|
msgs = []
|
||||||
msgs.extend(self.extra_msg)
|
msgs.extend(self.extra_msg)
|
||||||
|
@ -1,19 +1,17 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import field, dataclass
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from nonebot.adapters.onebot.v11 import MessageSegment
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.plugin import require
|
from nonebot.plugin import require
|
||||||
from nonebot_plugin_saa import Image, MessageFactory, MessageSegmentFactory, Text
|
from nonebot.adapters.onebot.v11 import MessageSegment
|
||||||
|
from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
|
||||||
|
|
||||||
from .abstract_post import AbstractPost, BasePost
|
from .abstract_post import BasePost, AbstractPost
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _CustomPost(BasePost):
|
class _CustomPost(BasePost):
|
||||||
|
|
||||||
ms_factories: list[MessageSegmentFactory] = field(default_factory=list)
|
ms_factories: list[MessageSegmentFactory] = field(default_factory=list)
|
||||||
css_path: Optional[str] = None # 模板文件所用css路径
|
css_path: str | None = None # 模板文件所用css路径
|
||||||
|
|
||||||
async def generate_text_messages(self) -> list[MessageSegmentFactory]:
|
async def generate_text_messages(self) -> list[MessageSegmentFactory]:
|
||||||
return self.ms_factories
|
return self.ms_factories
|
||||||
@ -31,15 +29,13 @@ class _CustomPost(BasePost):
|
|||||||
for message_segment in self.ms_factories:
|
for message_segment in self.ms_factories:
|
||||||
match message_segment:
|
match message_segment:
|
||||||
case Text(data={"text": text}):
|
case Text(data={"text": text}):
|
||||||
md += "{}<br>".format(text)
|
md += f"{text}<br>"
|
||||||
case Image(data={"image": image}):
|
case Image(data={"image": image}):
|
||||||
# use onebot v11 to convert image into url
|
# use onebot v11 to convert image into url
|
||||||
ob11_image = MessageSegment.image(image)
|
ob11_image = MessageSegment.image(image)
|
||||||
md += "\n".format(ob11_image.data["file"])
|
md += "\n".format(ob11_image.data["file"])
|
||||||
case _:
|
case _:
|
||||||
logger.warning(
|
logger.warning(f"custom_post不支持处理类型:{type(message_segment)}")
|
||||||
"custom_post不支持处理类型:{}".format(type(message_segment))
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return md
|
return md
|
||||||
|
@ -1,30 +1,27 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
from functools import reduce
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional, Union
|
from dataclasses import field, dataclass
|
||||||
|
|
||||||
import nonebot_plugin_saa as saa
|
|
||||||
from nonebot.log import logger
|
|
||||||
from nonebot_plugin_saa.utils import MessageFactory, MessageSegmentFactory
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from nonebot.log import logger
|
||||||
|
import nonebot_plugin_saa as saa
|
||||||
|
from nonebot_plugin_saa.utils import MessageSegmentFactory
|
||||||
|
|
||||||
from ..utils import http_client, parse_text
|
from ..utils import parse_text, http_client
|
||||||
from .abstract_post import AbstractPost, BasePost, OptionalMixin
|
from .abstract_post import BasePost, AbstractPost
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _Post(BasePost):
|
class _Post(BasePost):
|
||||||
|
|
||||||
target_type: str
|
target_type: str
|
||||||
text: str
|
text: str
|
||||||
url: Optional[str] = None
|
url: str | None = None
|
||||||
target_name: Optional[str] = None
|
target_name: str | None = None
|
||||||
pics: list[Union[str, bytes]] = field(default_factory=list)
|
pics: list[str | bytes] = field(default_factory=list)
|
||||||
|
|
||||||
_message: Optional[list[MessageSegmentFactory]] = None
|
_message: list[MessageSegmentFactory] | None = None
|
||||||
_pic_message: Optional[list[MessageSegmentFactory]] = None
|
_pic_message: list[MessageSegmentFactory] | None = None
|
||||||
|
|
||||||
async def _pic_url_to_image(self, data: Union[str, bytes]) -> Image.Image:
|
async def _pic_url_to_image(self, data: str | bytes) -> Image.Image:
|
||||||
pic_buffer = BytesIO()
|
pic_buffer = BytesIO()
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
async with http_client() as client:
|
async with http_client() as client:
|
||||||
@ -101,22 +98,19 @@ class _Post(BasePost):
|
|||||||
self.pics.insert(0, target_io.getvalue())
|
self.pics.insert(0, target_io.getvalue())
|
||||||
|
|
||||||
async def generate_text_messages(self) -> list[MessageSegmentFactory]:
|
async def generate_text_messages(self) -> list[MessageSegmentFactory]:
|
||||||
|
|
||||||
if self._message is None:
|
if self._message is None:
|
||||||
await self._pic_merge()
|
await self._pic_merge()
|
||||||
msg_segments: list[MessageSegmentFactory] = []
|
msg_segments: list[MessageSegmentFactory] = []
|
||||||
text = ""
|
text = ""
|
||||||
if self.text:
|
if self.text:
|
||||||
text += "{}".format(
|
text += "{}".format(self.text if len(self.text) < 500 else self.text[:500] + "...")
|
||||||
self.text if len(self.text) < 500 else self.text[:500] + "..."
|
|
||||||
)
|
|
||||||
if text:
|
if text:
|
||||||
text += "\n"
|
text += "\n"
|
||||||
text += "来源: {}".format(self.target_type)
|
text += f"来源: {self.target_type}"
|
||||||
if self.target_name:
|
if self.target_name:
|
||||||
text += " {}".format(self.target_name)
|
text += f" {self.target_name}"
|
||||||
if self.url:
|
if self.url:
|
||||||
text += " \n详情: {}".format(self.url)
|
text += f" \n详情: {self.url}"
|
||||||
msg_segments.append(saa.Text(text))
|
msg_segments.append(saa.Text(text))
|
||||||
for pic in self.pics:
|
for pic in self.pics:
|
||||||
msg_segments.append(saa.Image(pic))
|
msg_segments.append(saa.Image(pic))
|
||||||
@ -124,17 +118,16 @@ class _Post(BasePost):
|
|||||||
return self._message
|
return self._message
|
||||||
|
|
||||||
async def generate_pic_messages(self) -> list[MessageSegmentFactory]:
|
async def generate_pic_messages(self) -> list[MessageSegmentFactory]:
|
||||||
|
|
||||||
if self._pic_message is None:
|
if self._pic_message is None:
|
||||||
await self._pic_merge()
|
await self._pic_merge()
|
||||||
msg_segments: list[MessageSegmentFactory] = []
|
msg_segments: list[MessageSegmentFactory] = []
|
||||||
text = ""
|
text = ""
|
||||||
if self.text:
|
if self.text:
|
||||||
text += "{}".format(self.text)
|
text += f"{self.text}"
|
||||||
text += "\n"
|
text += "\n"
|
||||||
text += "来源: {}".format(self.target_type)
|
text += f"来源: {self.target_type}"
|
||||||
if self.target_name:
|
if self.target_name:
|
||||||
text += " {}".format(self.target_name)
|
text += f" {self.target_name}"
|
||||||
msg_segments.append(await parse_text(text))
|
msg_segments.append(await parse_text(text))
|
||||||
if not self.target_type == "rss" and self.url:
|
if not self.target_type == "rss" and self.url:
|
||||||
msg_segments.append(saa.Text(self.url))
|
msg_segments.append(saa.Text(self.url))
|
||||||
@ -149,14 +142,7 @@ class _Post(BasePost):
|
|||||||
self.target_name,
|
self.target_name,
|
||||||
self.text if len(self.text) < 500 else self.text[:500] + "...",
|
self.text if len(self.text) < 500 else self.text[:500] + "...",
|
||||||
self.url,
|
self.url,
|
||||||
", ".join(
|
", ".join("b64img" if isinstance(x, bytes) or x.startswith("base64") else x for x in self.pics),
|
||||||
map(
|
|
||||||
lambda x: "b64img"
|
|
||||||
if isinstance(x, bytes) or x.startswith("base64")
|
|
||||||
else x,
|
|
||||||
self.pics,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1 +1,3 @@
|
|||||||
from .manager import *
|
from .manager import init_scheduler, scheduler_dict, handle_delete_target, handle_insert_new_target
|
||||||
|
|
||||||
|
__all__ = ["init_scheduler", "handle_delete_target", "handle_insert_new_target", "scheduler_dict"]
|
||||||
|
@ -1,18 +1,16 @@
|
|||||||
from typing import Type
|
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..config.db_model import Target
|
|
||||||
from ..platform import platform_manager
|
|
||||||
from ..types import Target as T_Target
|
|
||||||
from ..utils import SchedulerConfig
|
|
||||||
from .scheduler import Scheduler
|
from .scheduler import Scheduler
|
||||||
|
from ..utils import SchedulerConfig
|
||||||
|
from ..config.db_model import Target
|
||||||
|
from ..types import Target as T_Target
|
||||||
|
from ..platform import platform_manager
|
||||||
|
|
||||||
scheduler_dict: dict[Type[SchedulerConfig], Scheduler] = {}
|
scheduler_dict: dict[type[SchedulerConfig], Scheduler] = {}
|
||||||
|
|
||||||
|
|
||||||
async def init_scheduler():
|
async def init_scheduler():
|
||||||
_schedule_class_dict: dict[Type[SchedulerConfig], list[Target]] = {}
|
_schedule_class_dict: dict[type[SchedulerConfig], list[Target]] = {}
|
||||||
_schedule_class_platform_dict: dict[Type[SchedulerConfig], list[str]] = {}
|
_schedule_class_platform_dict: dict[type[SchedulerConfig], list[str]] = {}
|
||||||
for platform in platform_manager.values():
|
for platform in platform_manager.values():
|
||||||
scheduler_config = platform.scheduler
|
scheduler_config = platform.scheduler
|
||||||
if not hasattr(scheduler_config, "name") or not scheduler_config.name:
|
if not hasattr(scheduler_config, "name") or not scheduler_config.name:
|
||||||
@ -33,9 +31,7 @@ async def init_scheduler():
|
|||||||
for target in target_list:
|
for target in target_list:
|
||||||
schedulable_args.append((target.platform_name, T_Target(target.target)))
|
schedulable_args.append((target.platform_name, T_Target(target.target)))
|
||||||
platform_name_list = _schedule_class_platform_dict[scheduler_config]
|
platform_name_list = _schedule_class_platform_dict[scheduler_config]
|
||||||
scheduler_dict[scheduler_config] = Scheduler(
|
scheduler_dict[scheduler_config] = Scheduler(scheduler_config, schedulable_args, platform_name_list)
|
||||||
scheduler_config, schedulable_args, platform_name_list
|
|
||||||
)
|
|
||||||
config.register_add_target_hook(handle_insert_new_target)
|
config.register_add_target_hook(handle_insert_new_target)
|
||||||
config.register_delete_target_hook(handle_delete_target)
|
config.register_delete_target_hook(handle_delete_target)
|
||||||
|
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Type
|
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot_plugin_apscheduler import scheduler
|
from nonebot_plugin_apscheduler import scheduler
|
||||||
from nonebot_plugin_saa.utils.exceptions import NoBotFound
|
from nonebot_plugin_saa.utils.exceptions import NoBotFound
|
||||||
|
|
||||||
from ..config import config
|
|
||||||
from ..platform import platform_manager
|
|
||||||
from ..send import send_msgs
|
|
||||||
from ..types import Target
|
from ..types import Target
|
||||||
|
from ..config import config
|
||||||
|
from ..send import send_msgs
|
||||||
|
from ..platform import platform_manager
|
||||||
from ..utils import ProcessContext, SchedulerConfig
|
from ..utils import ProcessContext, SchedulerConfig
|
||||||
|
|
||||||
|
|
||||||
@ -20,12 +19,11 @@ class Schedulable:
|
|||||||
|
|
||||||
|
|
||||||
class Scheduler:
|
class Scheduler:
|
||||||
|
|
||||||
schedulable_list: list[Schedulable]
|
schedulable_list: list[Schedulable]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
scheduler_config: Type[SchedulerConfig],
|
scheduler_config: type[SchedulerConfig],
|
||||||
schedulables: list[tuple[str, Target]],
|
schedulables: list[tuple[str, Target]],
|
||||||
platform_name_list: list[str],
|
platform_name_list: list[str],
|
||||||
):
|
):
|
||||||
@ -37,15 +35,12 @@ class Scheduler:
|
|||||||
self.scheduler_config_obj = self.scheduler_config()
|
self.scheduler_config_obj = self.scheduler_config()
|
||||||
self.schedulable_list = []
|
self.schedulable_list = []
|
||||||
for platform_name, target in schedulables:
|
for platform_name, target in schedulables:
|
||||||
self.schedulable_list.append(
|
self.schedulable_list.append(Schedulable(platform_name=platform_name, target=target, current_weight=0))
|
||||||
Schedulable(
|
|
||||||
platform_name=platform_name, target=target, current_weight=0
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.platform_name_list = platform_name_list
|
self.platform_name_list = platform_name_list
|
||||||
self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值
|
self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值
|
||||||
logger.info(
|
logger.info(
|
||||||
f"register scheduler for {self.name} with {self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}"
|
f"register scheduler for {self.name} with "
|
||||||
|
f"{self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}"
|
||||||
)
|
)
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
self.exec_fetch,
|
self.exec_fetch,
|
||||||
@ -53,7 +48,7 @@ class Scheduler:
|
|||||||
**self.scheduler_config.schedule_setting,
|
**self.scheduler_config.schedule_setting,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_next_schedulable(self) -> Optional[Schedulable]:
|
async def get_next_schedulable(self) -> Schedulable | None:
|
||||||
if not self.schedulable_list:
|
if not self.schedulable_list:
|
||||||
return None
|
return None
|
||||||
cur_weight = await config.get_current_weight_val(self.platform_name_list)
|
cur_weight = await config.get_current_weight_val(self.platform_name_list)
|
||||||
@ -61,16 +56,9 @@ class Scheduler:
|
|||||||
self.pre_weight_val = 0
|
self.pre_weight_val = 0
|
||||||
cur_max_schedulable = None
|
cur_max_schedulable = None
|
||||||
for schedulable in self.schedulable_list:
|
for schedulable in self.schedulable_list:
|
||||||
schedulable.current_weight += cur_weight[
|
schedulable.current_weight += cur_weight[f"{schedulable.platform_name}-{schedulable.target}"]
|
||||||
f"{schedulable.platform_name}-{schedulable.target}"
|
weight_sum += cur_weight[f"{schedulable.platform_name}-{schedulable.target}"]
|
||||||
]
|
if not cur_max_schedulable or cur_max_schedulable.current_weight < schedulable.current_weight:
|
||||||
weight_sum += cur_weight[
|
|
||||||
f"{schedulable.platform_name}-{schedulable.target}"
|
|
||||||
]
|
|
||||||
if (
|
|
||||||
not cur_max_schedulable
|
|
||||||
or cur_max_schedulable.current_weight < schedulable.current_weight
|
|
||||||
):
|
|
||||||
cur_max_schedulable = schedulable
|
cur_max_schedulable = schedulable
|
||||||
assert cur_max_schedulable
|
assert cur_max_schedulable
|
||||||
cur_max_schedulable.current_weight -= weight_sum
|
cur_max_schedulable.current_weight -= weight_sum
|
||||||
@ -80,9 +68,7 @@ class Scheduler:
|
|||||||
context = ProcessContext()
|
context = ProcessContext()
|
||||||
if not (schedulable := await self.get_next_schedulable()):
|
if not (schedulable := await self.get_next_schedulable()):
|
||||||
return
|
return
|
||||||
logger.trace(
|
logger.trace(f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}")
|
||||||
f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}"
|
|
||||||
)
|
|
||||||
send_userinfo_list = await config.get_platform_target_subscribers(
|
send_userinfo_list = await config.get_platform_target_subscribers(
|
||||||
schedulable.platform_name, schedulable.target
|
schedulable.platform_name, schedulable.target
|
||||||
)
|
)
|
||||||
@ -92,9 +78,7 @@ class Scheduler:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
platform_obj = platform_manager[schedulable.platform_name](context, client)
|
platform_obj = platform_manager[schedulable.platform_name](context, client)
|
||||||
to_send = await platform_obj.do_fetch_new_post(
|
to_send = await platform_obj.do_fetch_new_post(schedulable.target, send_userinfo_list)
|
||||||
schedulable.target, send_userinfo_list
|
|
||||||
)
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
records = context.gen_req_records()
|
records = context.gen_req_records()
|
||||||
for record in records:
|
for record in records:
|
||||||
@ -107,7 +91,7 @@ class Scheduler:
|
|||||||
|
|
||||||
for user, send_list in to_send:
|
for user, send_list in to_send:
|
||||||
for send_post in send_list:
|
for send_post in send_list:
|
||||||
logger.info("send to {}: {}".format(user, send_post))
|
logger.info(f"send to {user}: {send_post}")
|
||||||
try:
|
try:
|
||||||
await send_msgs(
|
await send_msgs(
|
||||||
user,
|
user,
|
||||||
@ -119,19 +103,14 @@ class Scheduler:
|
|||||||
def insert_new_schedulable(self, platform_name: str, target: Target):
|
def insert_new_schedulable(self, platform_name: str, target: Target):
|
||||||
self.pre_weight_val += 1000
|
self.pre_weight_val += 1000
|
||||||
self.schedulable_list.append(Schedulable(platform_name, target, 1000))
|
self.schedulable_list.append(Schedulable(platform_name, target, 1000))
|
||||||
logger.info(
|
logger.info(f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})")
|
||||||
f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})"
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete_schedulable(self, platform_name, target: Target):
|
def delete_schedulable(self, platform_name, target: Target):
|
||||||
if not self.schedulable_list:
|
if not self.schedulable_list:
|
||||||
return
|
return
|
||||||
to_find_idx = None
|
to_find_idx = None
|
||||||
for idx, schedulable in enumerate(self.schedulable_list):
|
for idx, schedulable in enumerate(self.schedulable_list):
|
||||||
if (
|
if schedulable.platform_name == platform_name and schedulable.target == target:
|
||||||
schedulable.platform_name == platform_name
|
|
||||||
and schedulable.target == target
|
|
||||||
):
|
|
||||||
to_find_idx = idx
|
to_find_idx = idx
|
||||||
break
|
break
|
||||||
if to_find_idx is not None:
|
if to_find_idx is not None:
|
||||||
|
@ -1,21 +1,23 @@
|
|||||||
import importlib
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from functools import partial, wraps
|
import importlib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Any, Callable, Coroutine, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
from functools import wraps, partial
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
|
|
||||||
from ..config.subs_io import subscribes_export, subscribes_import
|
|
||||||
from ..config.subs_io.nbesf_model import v1, v2
|
|
||||||
from ..scheduler.manager import init_scheduler
|
from ..scheduler.manager import init_scheduler
|
||||||
|
from ..config.subs_io.nbesf_model import v1, v2
|
||||||
|
from ..config.subs_io import subscribes_export, subscribes_import
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import click
|
import click
|
||||||
from typing_extensions import ParamSpec
|
|
||||||
except ImportError as e: # pragma: no cover
|
except ImportError as e: # pragma: no cover
|
||||||
raise ImportError("请使用 `pip install nonebot-bison[cli]` 安装所需依赖") from e
|
raise ImportError("请使用 `pip install nonebot-bison[cli]` 安装所需依赖") from e
|
||||||
|
|
||||||
@ -65,9 +67,7 @@ def path_init(ctx, param, value):
|
|||||||
|
|
||||||
|
|
||||||
@cli.command(help="导出Nonebot Bison Exchangable Subcribes File", name="export")
|
@cli.command(help="导出Nonebot Bison Exchangable Subcribes File", name="export")
|
||||||
@click.option(
|
@click.option("--path", "-p", default=None, callback=path_init, help="导出路径, 如果不指定,则默认为工作目录")
|
||||||
"--path", "-p", default=None, callback=path_init, help="导出路径, 如果不指定,则默认为工作目录"
|
|
||||||
)
|
|
||||||
@click.option(
|
@click.option(
|
||||||
"--format",
|
"--format",
|
||||||
default="json",
|
default="json",
|
||||||
@ -76,7 +76,6 @@ def path_init(ctx, param, value):
|
|||||||
)
|
)
|
||||||
@run_async
|
@run_async
|
||||||
async def subs_export(path: Path, format: str):
|
async def subs_export(path: Path, format: str):
|
||||||
|
|
||||||
await init_scheduler()
|
await init_scheduler()
|
||||||
|
|
||||||
export_file = path / f"bison_subscribes_export_{int(time.time())}.{format}"
|
export_file = path / f"bison_subscribes_export_{int(time.time())}.{format}"
|
||||||
@ -121,7 +120,6 @@ async def subs_export(path: Path, format: str):
|
|||||||
)
|
)
|
||||||
@run_async
|
@run_async
|
||||||
async def subs_import(path: str, format: str):
|
async def subs_import(path: str, format: str):
|
||||||
|
|
||||||
await init_scheduler()
|
await init_scheduler()
|
||||||
|
|
||||||
import_file_path = Path(path)
|
import_file_path = Path(path)
|
||||||
|
@ -1,17 +1,16 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Deque
|
|
||||||
|
|
||||||
from nonebot.adapters.onebot.v11.exception import ActionFailed
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot_plugin_saa import AggregatedMessageFactory, MessageFactory, PlatformTarget
|
from nonebot.adapters.onebot.v11.exception import ActionFailed
|
||||||
from nonebot_plugin_saa.utils.auto_select_bot import refresh_bots
|
from nonebot_plugin_saa.utils.auto_select_bot import refresh_bots
|
||||||
|
from nonebot_plugin_saa import MessageFactory, PlatformTarget, AggregatedMessageFactory
|
||||||
|
|
||||||
from .plugin_config import plugin_config
|
from .plugin_config import plugin_config
|
||||||
|
|
||||||
Sendable = MessageFactory | AggregatedMessageFactory
|
Sendable = MessageFactory | AggregatedMessageFactory
|
||||||
|
|
||||||
QUEUE: Deque[tuple[PlatformTarget, Sendable, int]] = deque()
|
QUEUE: deque[tuple[PlatformTarget, Sendable, int]] = deque()
|
||||||
|
|
||||||
MESSGE_SEND_INTERVAL = 1.5
|
MESSGE_SEND_INTERVAL = 1.5
|
||||||
|
|
||||||
|
@ -1,21 +1,20 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
from typing import Type, cast
|
|
||||||
|
|
||||||
from nonebot.adapters import Message, MessageTemplate
|
from nonebot.typing import T_State
|
||||||
from nonebot.matcher import Matcher
|
from nonebot.matcher import Matcher
|
||||||
from nonebot.params import Arg, ArgPlainText
|
from nonebot.params import Arg, ArgPlainText
|
||||||
from nonebot.typing import T_State
|
from nonebot.adapters import Message, MessageTemplate
|
||||||
from nonebot_plugin_saa import PlatformTarget, SupportedAdapters, Text
|
from nonebot_plugin_saa import Text, PlatformTarget, SupportedAdapters
|
||||||
|
|
||||||
from ..apis import check_sub_target
|
|
||||||
from ..config import config
|
|
||||||
from ..config.db_config import SubscribeDupException
|
|
||||||
from ..platform import Platform, platform_manager
|
|
||||||
from ..types import Target
|
from ..types import Target
|
||||||
|
from ..config import config
|
||||||
|
from ..apis import check_sub_target
|
||||||
|
from ..platform import Platform, platform_manager
|
||||||
|
from ..config.db_config import SubscribeDupException
|
||||||
from .utils import common_platform, ensure_user_info, gen_handle_cancel
|
from .utils import common_platform, ensure_user_info, gen_handle_cancel
|
||||||
|
|
||||||
|
|
||||||
def do_add_sub(add_sub: Type[Matcher]):
|
def do_add_sub(add_sub: type[Matcher]):
|
||||||
handle_cancel = gen_handle_cancel(add_sub, "已中止订阅")
|
handle_cancel = gen_handle_cancel(add_sub, "已中止订阅")
|
||||||
|
|
||||||
add_sub.handle()(ensure_user_info(add_sub))
|
add_sub.handle()(ensure_user_info(add_sub))
|
||||||
@ -25,12 +24,7 @@ def do_add_sub(add_sub: Type[Matcher]):
|
|||||||
state["_prompt"] = (
|
state["_prompt"] = (
|
||||||
"请输入想要订阅的平台,目前支持,请输入冒号左边的名称:\n"
|
"请输入想要订阅的平台,目前支持,请输入冒号左边的名称:\n"
|
||||||
+ "".join(
|
+ "".join(
|
||||||
[
|
[f"{platform_name}: {platform_manager[platform_name].name}\n" for platform_name in common_platform]
|
||||||
"{}:{}\n".format(
|
|
||||||
platform_name, platform_manager[platform_name].name
|
|
||||||
)
|
|
||||||
for platform_name in common_platform
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
+ "要查看全部平台请输入:“全部”\n中止订阅过程请输入:“取消”"
|
+ "要查看全部平台请输入:“全部”\n中止订阅过程请输入:“取消”"
|
||||||
)
|
)
|
||||||
@ -39,10 +33,7 @@ def do_add_sub(add_sub: Type[Matcher]):
|
|||||||
async def parse_platform(state: T_State, platform: str = ArgPlainText()) -> None:
|
async def parse_platform(state: T_State, platform: str = ArgPlainText()) -> None:
|
||||||
if platform == "全部":
|
if platform == "全部":
|
||||||
message = "全部平台\n" + "\n".join(
|
message = "全部平台\n" + "\n".join(
|
||||||
[
|
[f"{platform_name}: {platform.name}" for platform_name, platform in platform_manager.items()]
|
||||||
"{}:{}".format(platform_name, platform.name)
|
|
||||||
for platform_name, platform in platform_manager.items()
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
await add_sub.reject(message)
|
await add_sub.reject(message)
|
||||||
elif platform == "取消":
|
elif platform == "取消":
|
||||||
@ -57,9 +48,7 @@ def do_add_sub(add_sub: Type[Matcher]):
|
|||||||
cur_platform = platform_manager[state["platform"]]
|
cur_platform = platform_manager[state["platform"]]
|
||||||
if cur_platform.has_target:
|
if cur_platform.has_target:
|
||||||
state["_prompt"] = (
|
state["_prompt"] = (
|
||||||
("1." + cur_platform.parse_target_promot + "\n2.")
|
("1." + cur_platform.parse_target_promot + "\n2.") if cur_platform.parse_target_promot else ""
|
||||||
if cur_platform.parse_target_promot
|
|
||||||
else ""
|
|
||||||
) + "请输入订阅用户的id\n查询id获取方法请回复:“查询”"
|
) + "请输入订阅用户的id\n查询id获取方法请回复:“查询”"
|
||||||
else:
|
else:
|
||||||
matcher.set_arg("raw_id", None) # type: ignore
|
matcher.set_arg("raw_id", None) # type: ignore
|
||||||
@ -81,9 +70,7 @@ def do_add_sub(add_sub: Type[Matcher]):
|
|||||||
image = "https://s3.bmp.ovh/imgs/2022/03/ab3cc45d83bd3dd3.jpg"
|
image = "https://s3.bmp.ovh/imgs/2022/03/ab3cc45d83bd3dd3.jpg"
|
||||||
msg.overwrite(
|
msg.overwrite(
|
||||||
SupportedAdapters.onebot_v11,
|
SupportedAdapters.onebot_v11,
|
||||||
MessageSegment.share(
|
MessageSegment.share(url=url, title=title, content=content, image=image),
|
||||||
url=url, title=title, content=content, image=image
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
await msg.reject()
|
await msg.reject()
|
||||||
platform = platform_manager[state["platform"]]
|
platform = platform_manager[state["platform"]]
|
||||||
@ -99,14 +86,12 @@ def do_add_sub(add_sub: Type[Matcher]):
|
|||||||
await add_sub.reject("id输入错误")
|
await add_sub.reject("id输入错误")
|
||||||
state["id"] = raw_id_text
|
state["id"] = raw_id_text
|
||||||
state["name"] = name
|
state["name"] = name
|
||||||
except (Platform.ParseTargetException):
|
except Platform.ParseTargetException:
|
||||||
await add_sub.reject("不能从你的输入中提取出id,请检查你输入的内容是否符合预期")
|
await add_sub.reject("不能从你的输入中提取出id,请检查你输入的内容是否符合预期")
|
||||||
else:
|
else:
|
||||||
await add_sub.send(
|
await add_sub.send(
|
||||||
"即将订阅的用户为:{} {} {}\n如有错误请输入“取消”重新订阅".format(
|
f"即将订阅的用户为:{state['platform']} {state['name']} {state['id']}\n如有错误请输入“取消”重新订阅"
|
||||||
state["platform"], state["name"], state["id"]
|
) # noqa: E501
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@add_sub.handle()
|
@add_sub.handle()
|
||||||
async def prepare_get_categories(matcher: Matcher, state: T_State):
|
async def prepare_get_categories(matcher: Matcher, state: T_State):
|
||||||
@ -125,7 +110,7 @@ def do_add_sub(add_sub: Type[Matcher]):
|
|||||||
if platform_manager[state["platform"]].categories:
|
if platform_manager[state["platform"]].categories:
|
||||||
for cat in raw_cats_text.split():
|
for cat in raw_cats_text.split():
|
||||||
if cat not in platform_manager[state["platform"]].reverse_category:
|
if cat not in platform_manager[state["platform"]].reverse_category:
|
||||||
await add_sub.reject("不支持 {}".format(cat))
|
await add_sub.reject(f"不支持 {cat}")
|
||||||
res.append(platform_manager[state["platform"]].reverse_category[cat])
|
res.append(platform_manager[state["platform"]].reverse_category[cat])
|
||||||
state["cats"] = res
|
state["cats"] = res
|
||||||
|
|
||||||
@ -135,14 +120,16 @@ def do_add_sub(add_sub: Type[Matcher]):
|
|||||||
matcher.set_arg("raw_tags", None) # type: ignore
|
matcher.set_arg("raw_tags", None) # type: ignore
|
||||||
state["tags"] = []
|
state["tags"] = []
|
||||||
return
|
return
|
||||||
state["_prompt"] = '请输入要订阅/屏蔽的标签(不含#号)\n多个标签请使用空格隔开\n订阅所有标签输入"全部标签"\n具体规则回复"详情"'
|
state["_prompt"] = "请输入要订阅/屏蔽的标签(不含#号)\n" "多个标签请使用空格隔开\n" '订阅所有标签输入"全部标签"\n' '具体规则回复"详情"' # noqa: E501
|
||||||
|
|
||||||
@add_sub.got("raw_tags", MessageTemplate("{_prompt}"), [handle_cancel])
|
@add_sub.got("raw_tags", MessageTemplate("{_prompt}"), [handle_cancel])
|
||||||
async def parser_tags(state: T_State, raw_tags: Message = Arg()):
|
async def parser_tags(state: T_State, raw_tags: Message = Arg()):
|
||||||
raw_tags_text = raw_tags.extract_plain_text()
|
raw_tags_text = raw_tags.extract_plain_text()
|
||||||
if raw_tags_text == "详情":
|
if raw_tags_text == "详情":
|
||||||
await add_sub.reject(
|
await add_sub.reject(
|
||||||
"订阅标签直接输入标签内容\n屏蔽标签请在标签名称前添加~号\n详见https://nonebot-bison.netlify.app/usage/#%E5%B9%B3%E5%8F%B0%E8%AE%A2%E9%98%85%E6%A0%87%E7%AD%BE-tag"
|
"订阅标签直接输入标签内容\n"
|
||||||
|
"屏蔽标签请在标签名称前添加~号\n"
|
||||||
|
"详见https://nonebot-bison.netlify.app/usage/#%E5%B9%B3%E5%8F%B0%E8%AE%A2%E9%98%85%E6%A0%87%E7%AD%BE-tag"
|
||||||
)
|
)
|
||||||
if raw_tags_text in ["全部标签", "全部", "全标签"]:
|
if raw_tags_text in ["全部标签", "全部", "全标签"]:
|
||||||
state["tags"] = []
|
state["tags"] = []
|
||||||
@ -150,9 +137,7 @@ def do_add_sub(add_sub: Type[Matcher]):
|
|||||||
state["tags"] = raw_tags_text.split()
|
state["tags"] = raw_tags_text.split()
|
||||||
|
|
||||||
@add_sub.handle()
|
@add_sub.handle()
|
||||||
async def add_sub_process(
|
async def add_sub_process(state: T_State, user: PlatformTarget = Arg("target_user_info")):
|
||||||
state: T_State, user: PlatformTarget = Arg("target_user_info")
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
await config.add_subscribe(
|
await config.add_subscribe(
|
||||||
user=user,
|
user=user,
|
||||||
|
@ -1,26 +1,22 @@
|
|||||||
from typing import Type
|
from nonebot.typing import T_State
|
||||||
|
|
||||||
from nonebot.matcher import Matcher
|
from nonebot.matcher import Matcher
|
||||||
from nonebot.params import Arg, EventPlainText
|
from nonebot.params import Arg, EventPlainText
|
||||||
from nonebot.typing import T_State
|
|
||||||
from nonebot_plugin_saa import MessageFactory, PlatformTarget
|
from nonebot_plugin_saa import MessageFactory, PlatformTarget
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..platform import platform_manager
|
|
||||||
from ..types import Category
|
from ..types import Category
|
||||||
from ..utils import parse_text
|
from ..utils import parse_text
|
||||||
|
from ..platform import platform_manager
|
||||||
from .utils import ensure_user_info, gen_handle_cancel
|
from .utils import ensure_user_info, gen_handle_cancel
|
||||||
|
|
||||||
|
|
||||||
def do_del_sub(del_sub: Type[Matcher]):
|
def do_del_sub(del_sub: type[Matcher]):
|
||||||
handle_cancel = gen_handle_cancel(del_sub, "删除中止")
|
handle_cancel = gen_handle_cancel(del_sub, "删除中止")
|
||||||
|
|
||||||
del_sub.handle()(ensure_user_info(del_sub))
|
del_sub.handle()(ensure_user_info(del_sub))
|
||||||
|
|
||||||
@del_sub.handle()
|
@del_sub.handle()
|
||||||
async def send_list(
|
async def send_list(state: T_State, user_info: PlatformTarget = Arg("target_user_info")):
|
||||||
state: T_State, user_info: PlatformTarget = Arg("target_user_info")
|
|
||||||
):
|
|
||||||
sub_list = await config.list_subscribe(user_info)
|
sub_list = await config.list_subscribe(user_info)
|
||||||
if not sub_list:
|
if not sub_list:
|
||||||
await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅")
|
await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅")
|
||||||
@ -31,22 +27,10 @@ def do_del_sub(del_sub: Type[Matcher]):
|
|||||||
"platform_name": sub.target.platform_name,
|
"platform_name": sub.target.platform_name,
|
||||||
"target": sub.target.target,
|
"target": sub.target.target,
|
||||||
}
|
}
|
||||||
res += "{} {} {} {}\n".format(
|
res += f"{index} {sub.target.platform_name} {sub.target.target_name} {sub.target.target}\n"
|
||||||
index,
|
|
||||||
sub.target.platform_name,
|
|
||||||
sub.target.target_name,
|
|
||||||
sub.target.target,
|
|
||||||
)
|
|
||||||
platform = platform_manager[sub.target.platform_name]
|
platform = platform_manager[sub.target.platform_name]
|
||||||
if platform.categories:
|
if platform.categories:
|
||||||
res += " [{}]".format(
|
res += " [{}]".format(", ".join(platform.categories[Category(x)] for x in sub.categories))
|
||||||
", ".join(
|
|
||||||
map(
|
|
||||||
lambda x: platform.categories[Category(x)],
|
|
||||||
sub.categories,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if platform.enable_tag:
|
if platform.enable_tag:
|
||||||
res += " {}".format(", ".join(sub.tags))
|
res += " {}".format(", ".join(sub.tags))
|
||||||
res += "\n"
|
res += "\n"
|
||||||
@ -62,7 +46,7 @@ def do_del_sub(del_sub: Type[Matcher]):
|
|||||||
try:
|
try:
|
||||||
index = int(index_str)
|
index = int(index_str)
|
||||||
await config.del_subscribe(user_info, **state["sub_table"][index])
|
await config.del_subscribe(user_info, **state["sub_table"][index])
|
||||||
except Exception as e:
|
except Exception:
|
||||||
await del_sub.reject("删除错误")
|
await del_sub.reject("删除错误")
|
||||||
else:
|
else:
|
||||||
await del_sub.finish("删除成功")
|
await del_sub.finish("删除成功")
|
||||||
|
@ -1,17 +1,15 @@
|
|||||||
from typing import Type
|
|
||||||
|
|
||||||
from nonebot.matcher import Matcher
|
|
||||||
from nonebot.params import Arg
|
from nonebot.params import Arg
|
||||||
|
from nonebot.matcher import Matcher
|
||||||
from nonebot_plugin_saa import MessageFactory, PlatformTarget
|
from nonebot_plugin_saa import MessageFactory, PlatformTarget
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..platform import platform_manager
|
|
||||||
from ..types import Category
|
from ..types import Category
|
||||||
from ..utils import parse_text
|
from ..utils import parse_text
|
||||||
from .utils import ensure_user_info
|
from .utils import ensure_user_info
|
||||||
|
from ..platform import platform_manager
|
||||||
|
|
||||||
|
|
||||||
def do_query_sub(query_sub: Type[Matcher]):
|
def do_query_sub(query_sub: type[Matcher]):
|
||||||
query_sub.handle()(ensure_user_info(query_sub))
|
query_sub.handle()(ensure_user_info(query_sub))
|
||||||
|
|
||||||
@query_sub.handle()
|
@query_sub.handle()
|
||||||
@ -19,19 +17,10 @@ def do_query_sub(query_sub: Type[Matcher]):
|
|||||||
sub_list = await config.list_subscribe(user_info)
|
sub_list = await config.list_subscribe(user_info)
|
||||||
res = "订阅的帐号为:\n"
|
res = "订阅的帐号为:\n"
|
||||||
for sub in sub_list:
|
for sub in sub_list:
|
||||||
res += "{} {} {}".format(
|
res += f"{sub.target.platform_name} {sub.target.target_name} {sub.target.target}"
|
||||||
# sub["target_type"], sub["target_name"], sub["target"]
|
|
||||||
sub.target.platform_name,
|
|
||||||
sub.target.target_name,
|
|
||||||
sub.target.target,
|
|
||||||
)
|
|
||||||
platform = platform_manager[sub.target.platform_name]
|
platform = platform_manager[sub.target.platform_name]
|
||||||
if platform.categories:
|
if platform.categories:
|
||||||
res += " [{}]".format(
|
res += " [{}]".format(", ".join(platform.categories[Category(x)] for x in sub.categories))
|
||||||
", ".join(
|
|
||||||
map(lambda x: platform.categories[Category(x)], sub.categories)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if platform.enable_tag:
|
if platform.enable_tag:
|
||||||
res += " {}".format(", ".join(sub.tags))
|
res += " {}".format(", ".join(sub.tags))
|
||||||
res += "\n"
|
res += "\n"
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
from typing import Annotated, Type
|
from typing import Annotated
|
||||||
|
|
||||||
from nonebot.adapters import Event
|
|
||||||
from nonebot.matcher import Matcher
|
|
||||||
from nonebot.params import Depends, EventPlainText, EventToMe
|
|
||||||
from nonebot.permission import SUPERUSER
|
|
||||||
from nonebot.rule import Rule
|
from nonebot.rule import Rule
|
||||||
|
from nonebot.adapters import Event
|
||||||
from nonebot.typing import T_State
|
from nonebot.typing import T_State
|
||||||
|
from nonebot.matcher import Matcher
|
||||||
|
from nonebot.permission import SUPERUSER
|
||||||
from nonebot_plugin_saa import extract_target
|
from nonebot_plugin_saa import extract_target
|
||||||
|
from nonebot.params import Depends, EventToMe, EventPlainText
|
||||||
|
|
||||||
from ..platform import platform_manager
|
from ..platform import platform_manager
|
||||||
from ..plugin_config import plugin_config
|
from ..plugin_config import plugin_config
|
||||||
@ -31,7 +31,7 @@ common_platform = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def gen_handle_cancel(matcher: Type[Matcher], message: str):
|
def gen_handle_cancel(matcher: type[Matcher], message: str):
|
||||||
async def _handle_cancel(text: Annotated[str, EventPlainText()]):
|
async def _handle_cancel(text: Annotated[str, EventPlainText()]):
|
||||||
if text == "取消":
|
if text == "取消":
|
||||||
await matcher.finish(message)
|
await matcher.finish(message)
|
||||||
@ -39,12 +39,10 @@ def gen_handle_cancel(matcher: Type[Matcher], message: str):
|
|||||||
return Depends(_handle_cancel)
|
return Depends(_handle_cancel)
|
||||||
|
|
||||||
|
|
||||||
def ensure_user_info(matcher: Type[Matcher]):
|
def ensure_user_info(matcher: type[Matcher]):
|
||||||
async def _check_user_info(state: T_State):
|
async def _check_user_info(state: T_State):
|
||||||
if not state.get("target_user_info"):
|
if not state.get("target_user_info"):
|
||||||
await matcher.finish(
|
await matcher.finish("No target_user_info set, this shouldn't happen, please issue")
|
||||||
"No target_user_info set, this shouldn't happen, please issue"
|
|
||||||
)
|
|
||||||
|
|
||||||
return _check_user_info
|
return _check_user_info
|
||||||
|
|
||||||
|
@ -1,17 +1,16 @@
|
|||||||
import difflib
|
import difflib
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import nonebot
|
import nonebot
|
||||||
from bs4 import BeautifulSoup as bs
|
|
||||||
from nonebot.log import default_format, logger
|
|
||||||
from nonebot.plugin import require
|
from nonebot.plugin import require
|
||||||
from nonebot_plugin_saa import Image, MessageSegmentFactory, Text
|
from bs4 import BeautifulSoup as bs
|
||||||
|
from nonebot.log import logger, default_format
|
||||||
|
from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
|
||||||
|
|
||||||
from ..plugin_config import plugin_config
|
|
||||||
from .context import ProcessContext
|
|
||||||
from .http import http_client
|
from .http import http_client
|
||||||
|
from .context import ProcessContext
|
||||||
|
from ..plugin_config import plugin_config
|
||||||
from .scheduler_config import SchedulerConfig, scheduler
|
from .scheduler_config import SchedulerConfig, scheduler
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -30,7 +29,7 @@ class Singleton(type):
|
|||||||
|
|
||||||
def __call__(cls, *args, **kwargs):
|
def __call__(cls, *args, **kwargs):
|
||||||
if cls not in cls._instances:
|
if cls not in cls._instances:
|
||||||
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
|
cls._instances[cls] = super().__call__(*args, **kwargs)
|
||||||
return cls._instances[cls]
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
|
||||||
@ -63,7 +62,7 @@ def html_to_text(html: str, query_dict: dict = {}) -> str:
|
|||||||
|
|
||||||
class Filter:
|
class Filter:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.level: Union[int, str] = "DEBUG"
|
self.level: int | str = "DEBUG"
|
||||||
|
|
||||||
def __call__(self, record):
|
def __call__(self, record):
|
||||||
module_name: str = record["name"]
|
module_name: str = record["name"]
|
||||||
@ -71,9 +70,7 @@ class Filter:
|
|||||||
if module:
|
if module:
|
||||||
module_name = getattr(module, "__module_name__", module_name)
|
module_name = getattr(module, "__module_name__", module_name)
|
||||||
record["name"] = module_name.split(".")[0]
|
record["name"] = module_name.split(".")[0]
|
||||||
levelno = (
|
levelno = logger.level(self.level).no if isinstance(self.level, str) else self.level
|
||||||
logger.level(self.level).no if isinstance(self.level, str) else self.level
|
|
||||||
)
|
|
||||||
nonebot_warning_level = logger.level("WARNING").no
|
nonebot_warning_level = logger.level("WARNING").no
|
||||||
return (
|
return (
|
||||||
record["level"].no >= levelno
|
record["level"].no >= levelno
|
||||||
@ -94,11 +91,7 @@ if plugin_config.bison_filter_log:
|
|||||||
)
|
)
|
||||||
config = nonebot.get_driver().config
|
config = nonebot.get_driver().config
|
||||||
logger.success("Muted info & success from nonebot")
|
logger.success("Muted info & success from nonebot")
|
||||||
default_filter.level = (
|
default_filter.level = ("DEBUG" if config.debug else "INFO") if config.log_level is None else config.log_level
|
||||||
("DEBUG" if config.debug else "INFO")
|
|
||||||
if config.log_level is None
|
|
||||||
else config.log_level
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def text_similarity(str1, str2) -> float:
|
def text_similarity(str1, str2) -> float:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
|
|
||||||
from httpx import AsyncClient, Response
|
from httpx import Response, AsyncClient
|
||||||
|
|
||||||
|
|
||||||
class ProcessContext:
|
class ProcessContext:
|
||||||
@ -33,8 +33,13 @@ class ProcessContext:
|
|||||||
res = []
|
res = []
|
||||||
for req in self.reqs:
|
for req in self.reqs:
|
||||||
if self._should_print_content(req):
|
if self._should_print_content(req):
|
||||||
log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} {req.text}"
|
log_content = (
|
||||||
|
f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} {req.text}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} b64encoded: {b64encode(req.content[:50]).decode()}"
|
log_content = (
|
||||||
|
f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} "
|
||||||
|
f"b64encoded: {b64encode(req.content[:50]).decode()}"
|
||||||
|
)
|
||||||
res.append(log_content)
|
res.append(log_content)
|
||||||
return res
|
return res
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
import functools
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from ..plugin_config import plugin_config
|
from ..plugin_config import plugin_config
|
||||||
@ -10,7 +8,6 @@ http_args = {
|
|||||||
http_headers = {"user-agent": plugin_config.bison_ua}
|
http_headers = {"user-agent": plugin_config.bison_ua}
|
||||||
|
|
||||||
|
|
||||||
@functools.wraps(httpx.AsyncClient)
|
|
||||||
def http_client(*args, **kwargs):
|
def http_client(*args, **kwargs):
|
||||||
if headers := kwargs.get("headers"):
|
if headers := kwargs.get("headers"):
|
||||||
new_headers = http_headers.copy()
|
new_headers = http_headers.copy()
|
||||||
@ -18,7 +15,4 @@ def http_client(*args, **kwargs):
|
|||||||
kwargs["headers"] = new_headers
|
kwargs["headers"] = new_headers
|
||||||
else:
|
else:
|
||||||
kwargs["headers"] = http_headers
|
kwargs["headers"] = http_headers
|
||||||
return httpx.AsyncClient(*args, **kwargs)
|
return httpx.AsyncClient(*args, **kwargs, **http_args)
|
||||||
|
|
||||||
|
|
||||||
http_client = functools.partial(http_client, **http_args)
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Literal, Type
|
from typing import Literal
|
||||||
|
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
@ -7,7 +7,6 @@ from .http import http_client
|
|||||||
|
|
||||||
|
|
||||||
class SchedulerConfig:
|
class SchedulerConfig:
|
||||||
|
|
||||||
schedule_type: Literal["date", "interval", "cron"]
|
schedule_type: Literal["date", "interval", "cron"]
|
||||||
schedule_setting: dict
|
schedule_setting: dict
|
||||||
name: str
|
name: str
|
||||||
@ -25,9 +24,7 @@ class SchedulerConfig:
|
|||||||
return self.default_http_client
|
return self.default_http_client
|
||||||
|
|
||||||
|
|
||||||
def scheduler(
|
def scheduler(schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict) -> type[SchedulerConfig]:
|
||||||
schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict
|
|
||||||
) -> Type[SchedulerConfig]:
|
|
||||||
return type(
|
return type(
|
||||||
"AnonymousScheduleConfig",
|
"AnonymousScheduleConfig",
|
||||||
(SchedulerConfig,),
|
(SchedulerConfig,),
|
||||||
|
@ -85,7 +85,7 @@ line-length = 120
|
|||||||
target-version = "py310"
|
target-version = "py310"
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 120
|
line-length = 118
|
||||||
target-version = ["py310", "py311"]
|
target-version = ["py310", "py311"]
|
||||||
include = '\.pyi?$'
|
include = '\.pyi?$'
|
||||||
extend-exclude = '''
|
extend-exclude = '''
|
||||||
|
Loading…
x
Reference in New Issue
Block a user