♻️ refactor client of scheduler_config

This commit is contained in:
felinae98
2024-05-28 20:59:39 +08:00
parent c21b24b385
commit 2fc11a9653
30 changed files with 185 additions and 143 deletions
+1 -1
View File
@@ -7,6 +7,6 @@ async def check_sub_target(platform_name: str, target: Target):
platform = platform_manager[platform_name]
scheduler_conf_class = platform.scheduler
scheduler = scheduler_dict[scheduler_conf_class]
client = await scheduler.scheduler_config_obj.get_query_name_client()
client = await scheduler.client_mgr.get_query_name_client()
return await platform_manager[platform_name].get_target_name(client, target)
+13 -9
View File
@@ -74,7 +74,8 @@ class Arknights(NewMessage):
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[BulletinListItem]:
raw_data = await self.client.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS")
client = await self.ctx.get_client()
raw_data = await client.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS")
return type_validate_python(ArkBulletinListResponse, raw_data.json()).data.list
def get_id(self, post: BulletinListItem) -> Any:
@@ -91,9 +92,8 @@ class Arknights(NewMessage):
return Category(1)
async def parse(self, raw_post: BulletinListItem) -> Post:
raw_data = await self.client.get(
f"https://ak-webview.hypergryph.com/api/game/bulletin/{self.get_id(post=raw_post)}"
)
client = await self.ctx.get_client()
raw_data = await client.get(f"https://ak-webview.hypergryph.com/api/game/bulletin/{self.get_id(post=raw_post)}")
data = type_validate_python(ArkBulletinResponse, raw_data.json()).data
def title_escape(text: str) -> str:
@@ -136,8 +136,9 @@ class AkVersion(StatusChange):
return "明日方舟游戏信息"
async def get_status(self, _):
res_ver = await self.client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version")
res_preanounce = await self.client.get(
client = await self.ctx.get_client()
res_ver = await client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version")
res_preanounce = await client.get(
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json"
)
res = res_ver.json()
@@ -179,7 +180,8 @@ class MonsterSiren(NewMessage):
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]:
raw_data = await self.client.get("https://monster-siren.hypergryph.com/api/news")
client = await self.ctx.get_client()
raw_data = await client.get("https://monster-siren.hypergryph.com/api/news")
return raw_data.json()["data"]["list"]
def get_id(self, post: RawPost) -> Any:
@@ -192,8 +194,9 @@ class MonsterSiren(NewMessage):
return Category(3)
async def parse(self, raw_post: RawPost) -> Post:
client = await self.ctx.get_client()
url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}'
res = await self.client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}')
res = await client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}')
raw_data = res.json()
content = raw_data["data"]["content"]
content = content.replace("</p>", "</p>\n")
@@ -226,7 +229,8 @@ class TerraHistoricusComic(NewMessage):
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]:
raw_data = await self.client.get("https://terra-historicus.hypergryph.com/api/recentUpdate")
client = await self.ctx.get_client()
raw_data = await client.get("https://terra-historicus.hypergryph.com/api/recentUpdate")
return raw_data.json()["data"]
def get_id(self, post: RawPost) -> Any:
+21 -29
View File
@@ -1,6 +1,5 @@
import re
import json
from abc import ABC
from copy import deepcopy
from enum import Enum, unique
from typing_extensions import Self
@@ -13,6 +12,7 @@ from pydantic import Field, BaseModel
from nonebot.compat import PYDANTIC_V2, ConfigDict, type_validate_json, type_validate_python
from nonebot_bison.compat import model_rebuild
from nonebot_bison.utils.scheduler_config import ClientManager
from ..post import Post
from ..types import Tag, Target, RawPost, ApiError, Category
@@ -104,7 +104,7 @@ model_rebuild_recurse(UserAPI)
model_rebuild_recurse(PostAPI)
class BilibiliClient:
class BilibiliClient(ClientManager):
_client: AsyncClient
_refresh_time: datetime
cookie_expire_time = timedelta(hours=5)
@@ -124,37 +124,27 @@ class BilibiliClient:
if datetime.now() - self._refresh_time > self.cookie_expire_time:
await self._init_session()
async def get_client(self) -> AsyncClient:
async def get_client(self, target: Target | None) -> AsyncClient:
await self._refresh_client()
return self._client
async def get_query_name_client(self) -> AsyncClient:
await self._refresh_client()
return self._client
bilibili_client = BilibiliClient()
class BaseSchedConf(ABC, SchedulerConfig):
schedule_type = "interval"
bilibili_client: BilibiliClient
def __init__(self):
super().__init__()
self.bilibili_client = bilibili_client
async def get_client(self, _: Target) -> AsyncClient:
return await self.bilibili_client.get_client()
async def get_query_name_client(self) -> AsyncClient:
return await self.bilibili_client.get_client()
class BilibiliSchedConf(BaseSchedConf):
class BilibiliSchedConf(SchedulerConfig):
name = "bilibili.com"
schedule_type = "interval"
schedule_setting = {"seconds": 10}
client_man = BilibiliClient
class BililiveSchedConf(BaseSchedConf):
class BililiveSchedConf(SchedulerConfig):
name = "live.bilibili.com"
schedule_type = "interval"
schedule_setting = {"seconds": 3}
client_man = BilibiliClient
class Bilibili(NewMessage):
@@ -198,8 +188,9 @@ class Bilibili(NewMessage):
)
async def get_sub_list(self, target: Target) -> list[DynRawPost]:
client = await self.ctx.get_client()
params = {"host_uid": target, "offset": 0, "need_top": 0}
res = await self.client.get(
res = await client.get(
"https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history",
params=params,
timeout=4.0,
@@ -428,8 +419,9 @@ class Bilibililive(StatusChange):
)
async def batch_get_status(self, targets: list[Target]) -> list[Info]:
client = await self.ctx.get_client()
# https://github.com/SocialSisterYi/bilibili-API-collect/blob/master/docs/live/info.md#批量查询直播间状态
res = await self.client.get(
res = await client.get(
"https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids",
params={"uids[]": targets},
timeout=4.0,
@@ -520,7 +512,8 @@ class BilibiliBangumi(StatusChange):
)
async def get_status(self, target: Target):
res = await self.client.get(
client = await self.ctx.get_client()
res = await client.get(
self._url,
params={"media_id": target},
timeout=4.0,
@@ -542,9 +535,8 @@ class BilibiliBangumi(StatusChange):
return []
async def parse(self, raw_post: RawPost) -> Post:
detail_res = await self.client.get(
f'https://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}'
)
client = await self.ctx.get_client()
detail_res = await client.get(f'https://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}')
detail_dict = detail_res.json()
lastest_episode = None
for episode in detail_dict["result"]["episodes"][::-1]:
+2 -1
View File
@@ -24,7 +24,8 @@ class FF14(NewMessage):
return "最终幻想XIV官方公告"
async def get_sub_list(self, _) -> list[RawPost]:
raw_data = await self.client.get(
client = await self.ctx.get_client()
raw_data = await client.get(
"https://cqnews.web.sdo.com/api/news/newsList?gameCode=ff&CategoryCode=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5"
)
return raw_data.json()["Data"]
+4 -2
View File
@@ -47,7 +47,8 @@ class NcmArtist(NewMessage):
raise cls.ParseTargetException("正确格式:\n1. 歌手数字ID\n2. https://music.163.com/#/artist?id=xxxx")
async def get_sub_list(self, target: Target) -> list[RawPost]:
res = await self.client.get(
client = await self.ctx.get_client()
res = await client.get(
f"https://music.163.com/api/artist/albums/{target}",
headers={"Referer": "https://music.163.com/"},
)
@@ -106,7 +107,8 @@ class NcmRadio(NewMessage):
)
async def get_sub_list(self, target: Target) -> list[RawPost]:
res = await self.client.post(
client = await self.ctx.get_client()
res = await client.post(
"http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"},
data={"radioId": target, "limit": 1000, "offset": 0},
+6 -8
View File
@@ -92,7 +92,6 @@ class Platform(metaclass=PlatformABCMeta, base=True):
platform_name: str
parse_target_promot: str | None = None
registry: list[type["Platform"]]
client: AsyncClient
reverse_category: dict[str, Category]
use_batch: bool = False
# TODO: 限定可使用的theme名称
@@ -121,9 +120,8 @@ class Platform(metaclass=PlatformABCMeta, base=True):
"actually function called"
return await self.parse(raw_post)
def __init__(self, context: ProcessContext, client: AsyncClient):
def __init__(self, context: ProcessContext):
super().__init__()
self.client = client
self.ctx = context
class ParseTargetException(Exception):
@@ -225,8 +223,8 @@ class Platform(metaclass=PlatformABCMeta, base=True):
class MessageProcess(Platform, abstract=True):
"General message process fetch, parse, filter progress"
def __init__(self, ctx: ProcessContext, client: AsyncClient):
super().__init__(ctx, client)
def __init__(self, ctx: ProcessContext):
super().__init__(ctx)
self.parse_cache: dict[Any, Post] = {}
@abstractmethod
@@ -463,11 +461,11 @@ def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
if platform.scheduler != scheduler:
raise RuntimeError(f"Platform scheduler for {platform_name} not fit")
def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient):
Platform.__init__(self, ctx, client)
def __init__(self: "NoTargetGroup", ctx: ProcessContext):
Platform.__init__(self, ctx)
self.platform_obj_list = []
for platform_class in self.platform_list:
self.platform_obj_list.append(platform_class(ctx, client))
self.platform_obj_list.append(platform_class(ctx))
def __str__(self: "NoTargetGroup") -> str:
return "[" + " ".join(x.name for x in self.platform_list) + "]"
+2 -1
View File
@@ -46,7 +46,8 @@ class Rss(NewMessage):
return post.id
async def get_sub_list(self, target: Target) -> list[RawPost]:
res = await self.client.get(target, timeout=10.0)
client = await self.ctx.get_client()
res = await client.get(target, timeout=10.0)
feed = feedparser.parse(res)
entries = feed.entries
for entry in entries:
+4 -2
View File
@@ -78,8 +78,9 @@ class Weibo(NewMessage):
raise cls.ParseTargetException(prompt="正确格式:\n1. 用户数字UID\n2. https://weibo.com/u/xxxx")
async def get_sub_list(self, target: Target) -> list[RawPost]:
client = await self.ctx.get_client()
params = {"containerid": "107603" + target}
res = await self.client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0)
res = await client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0)
res_data = json.loads(res.text)
if not res_data["ok"] and res_data["msg"] != "这里还没有内容":
raise ApiError(res.request.url)
@@ -149,7 +150,8 @@ class Weibo(NewMessage):
async def _get_long_weibo(self, weibo_id: str) -> dict:
try:
weibo_info = await self.client.get(
client = await self.ctx.get_client()
weibo_info = await client.get(
"https://m.weibo.cn/statuses/show",
params={"id": weibo_id},
headers=_HEADER,
+6 -4
View File
@@ -5,6 +5,8 @@ from nonebot.log import logger
from nonebot_plugin_apscheduler import scheduler
from nonebot_plugin_saa.utils.exceptions import NoBotFound
from nonebot_bison.utils.scheduler_config import ClientManager
from ..config import config
from ..send import send_msgs
from ..types import Target, SubUnit
@@ -24,6 +26,7 @@ class Scheduler:
schedulable_list: list[Schedulable] # for load weigth from db
batch_api_target_cache: dict[str, dict[Target, list[Target]]] # platform_name -> (target -> [target])
batch_platform_name_targets_cache: dict[str, list[Target]]
client_mgr: ClientManager
def __init__(
self,
@@ -36,6 +39,7 @@ class Scheduler:
logger.error(f"scheduler config [{self.name}] not found, exiting")
raise RuntimeError(f"{self.name} not found")
self.scheduler_config = scheduler_config
self.client_mgr = scheduler_config.client_mgr()
self.scheduler_config_obj = self.scheduler_config()
self.schedulable_list = []
@@ -83,16 +87,14 @@ class Scheduler:
return cur_max_schedulable
async def exec_fetch(self):
context = ProcessContext()
if not (schedulable := await self.get_next_schedulable()):
return
logger.trace(f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}")
client = await self.scheduler_config_obj.get_client(schedulable.target)
context.register_to_client(client)
context = ProcessContext(self.client_mgr)
try:
platform_obj = platform_manager[schedulable.platform_name](context, client)
platform_obj = platform_manager[schedulable.platform_name](context)
if schedulable.use_batch:
batch_targets = self.batch_api_target_cache[schedulable.platform_name][schedulable.target]
sub_units = []
+2 -1
View File
@@ -42,11 +42,12 @@ class BasicTheme(Theme):
if urls:
text += "\n".join(urls)
client = await post.platform.ctx.get_client_for_static()
msgs: list[MessageSegmentFactory] = [Text(text)]
if post.images:
pics = post.images
if is_pics_mergable(pics):
pics = await pic_merge(list(pics), post.platform.client)
pics = await pic_merge(list(pics), client)
msgs.extend(map(Image, pics))
return msgs
+2 -1
View File
@@ -29,11 +29,12 @@ class BriefTheme(Theme):
if urls:
text += "\n".join(urls)
client = await post.platform.ctx.get_client_for_static()
msgs: list[MessageSegmentFactory] = [Text(text)]
if post.images:
pics = post.images
if is_pics_mergable(pics):
pics = await pic_merge(list(pics), post.platform.client)
pics = await pic_merge(list(pics), client)
msgs.append(Image(pics[0]))
return msgs
+2 -1
View File
@@ -54,9 +54,10 @@ class Ht2iTheme(Theme):
msgs.append(Text("\n".join(urls)))
if post.images:
client = await post.platform.ctx.get_client_for_static()
pics = post.images
if is_pics_mergable(pics):
pics = await pic_merge(list(pics), post.platform.client)
pics = await pic_merge(list(pics), client)
msgs.extend(map(Image, pics))
return msgs
+3 -1
View File
@@ -11,14 +11,16 @@ from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
from .http import http_client
from .context import ProcessContext
from ..plugin_config import plugin_config
from .scheduler_config import SchedulerConfig, scheduler
from .image import pic_merge, text_to_image, is_pics_mergable, pic_url_to_image
from .scheduler_config import ClientManager, SchedulerConfig, DefaultClientManager, scheduler
__all__ = [
"http_client",
"Singleton",
"parse_text",
"ProcessContext",
"ClientManager",
"DefaultClientManager",
"html_to_text",
"SchedulerConfig",
"scheduler",
+23 -4
View File
@@ -2,19 +2,25 @@ from base64 import b64encode
from httpx import Response, AsyncClient
from nonebot_bison.types import Target
from .scheduler_config import ClientManager
class ProcessContext:
reqs: list[Response]
_client_mgr: ClientManager
def __init__(self) -> None:
def __init__(self, client_mgr: ClientManager) -> None:
self.reqs = []
self._client_mgr = client_mgr
def log_response(self, resp: Response):
def _log_response(self, resp: Response):
self.reqs.append(resp)
def register_to_client(self, client: AsyncClient):
def _register_to_client(self, client: AsyncClient):
async def _log_to_ctx(r: Response):
self.log_response(r)
self._log_response(r)
hooks = {
"response": [_log_to_ctx],
@@ -41,3 +47,16 @@ class ProcessContext:
)
res.append(log_content)
return res
async def get_client(self, target: Target | None = None) -> AsyncClient:
client = await self._client_mgr.get_client(target)
self._register_to_client(client)
return client
async def get_client_for_static(self) -> AsyncClient:
client = await self._client_mgr.get_client_for_static()
self._register_to_client(client)
return client
async def refresh_client(self):
await self._client_mgr.refresh_client()
+24 -6
View File
@@ -1,3 +1,4 @@
from abc import ABC
from typing import Literal
from httpx import AsyncClient
@@ -6,10 +7,32 @@ from ..types import Target
from .http import http_client
class ClientManager(ABC):
async def get_client(self, target: Target | None) -> AsyncClient: ...
async def get_client_for_static(self) -> AsyncClient: ...
async def get_query_name_client(self) -> AsyncClient: ...
async def refresh_client(self): ...
class DefaultClientManager(ClientManager):
async def get_client(self, target: Target | None) -> AsyncClient:
return http_client()
async def get_client_for_static(self) -> AsyncClient:
return http_client()
async def get_query_name_client(self) -> AsyncClient:
return http_client()
class SchedulerConfig:
schedule_type: Literal["date", "interval", "cron"]
schedule_setting: dict
name: str
client_mgr: type[ClientManager] = DefaultClientManager
require_browser: bool = False
def __str__(self):
@@ -18,12 +41,6 @@ class SchedulerConfig:
def __init__(self):
self.default_http_client = http_client()
async def get_client(self, target: Target) -> AsyncClient:
return self.default_http_client
async def get_query_name_client(self) -> AsyncClient:
return self.default_http_client
def scheduler(schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict) -> type[SchedulerConfig]:
return type(
@@ -32,5 +49,6 @@ def scheduler(schedule_type: Literal["date", "interval", "cron"], schedule_setti
{
"schedule_type": schedule_type,
"schedule_setting": schedule_setting,
"client_mgr": ClientManager,
},
)