♻️ refactor client of scheduler_config

This commit is contained in:
felinae98
2024-05-28 20:59:39 +08:00
parent c21b24b385
commit 2fc11a9653
30 changed files with 185 additions and 143 deletions
+13 -9
View File
@@ -74,7 +74,8 @@ class Arknights(NewMessage):
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[BulletinListItem]:
raw_data = await self.client.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS")
client = await self.ctx.get_client()
raw_data = await client.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS")
return type_validate_python(ArkBulletinListResponse, raw_data.json()).data.list
def get_id(self, post: BulletinListItem) -> Any:
@@ -91,9 +92,8 @@ class Arknights(NewMessage):
return Category(1)
async def parse(self, raw_post: BulletinListItem) -> Post:
raw_data = await self.client.get(
f"https://ak-webview.hypergryph.com/api/game/bulletin/{self.get_id(post=raw_post)}"
)
client = await self.ctx.get_client()
raw_data = await client.get(f"https://ak-webview.hypergryph.com/api/game/bulletin/{self.get_id(post=raw_post)}")
data = type_validate_python(ArkBulletinResponse, raw_data.json()).data
def title_escape(text: str) -> str:
@@ -136,8 +136,9 @@ class AkVersion(StatusChange):
return "明日方舟游戏信息"
async def get_status(self, _):
res_ver = await self.client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version")
res_preanounce = await self.client.get(
client = await self.ctx.get_client()
res_ver = await client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version")
res_preanounce = await client.get(
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json"
)
res = res_ver.json()
@@ -179,7 +180,8 @@ class MonsterSiren(NewMessage):
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]:
raw_data = await self.client.get("https://monster-siren.hypergryph.com/api/news")
client = await self.ctx.get_client()
raw_data = await client.get("https://monster-siren.hypergryph.com/api/news")
return raw_data.json()["data"]["list"]
def get_id(self, post: RawPost) -> Any:
@@ -192,8 +194,9 @@ class MonsterSiren(NewMessage):
return Category(3)
async def parse(self, raw_post: RawPost) -> Post:
client = await self.ctx.get_client()
url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}'
res = await self.client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}')
res = await client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}')
raw_data = res.json()
content = raw_data["data"]["content"]
content = content.replace("</p>", "</p>\n")
@@ -226,7 +229,8 @@ class TerraHistoricusComic(NewMessage):
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]:
raw_data = await self.client.get("https://terra-historicus.hypergryph.com/api/recentUpdate")
client = await self.ctx.get_client()
raw_data = await client.get("https://terra-historicus.hypergryph.com/api/recentUpdate")
return raw_data.json()["data"]
def get_id(self, post: RawPost) -> Any:
+21 -29
View File
@@ -1,6 +1,5 @@
import re
import json
from abc import ABC
from copy import deepcopy
from enum import Enum, unique
from typing_extensions import Self
@@ -13,6 +12,7 @@ from pydantic import Field, BaseModel
from nonebot.compat import PYDANTIC_V2, ConfigDict, type_validate_json, type_validate_python
from nonebot_bison.compat import model_rebuild
from nonebot_bison.utils.scheduler_config import ClientManager
from ..post import Post
from ..types import Tag, Target, RawPost, ApiError, Category
@@ -104,7 +104,7 @@ model_rebuild_recurse(UserAPI)
model_rebuild_recurse(PostAPI)
class BilibiliClient:
class BilibiliClient(ClientManager):
_client: AsyncClient
_refresh_time: datetime
cookie_expire_time = timedelta(hours=5)
@@ -124,37 +124,27 @@ class BilibiliClient:
if datetime.now() - self._refresh_time > self.cookie_expire_time:
await self._init_session()
async def get_client(self) -> AsyncClient:
async def get_client(self, target: Target | None) -> AsyncClient:
await self._refresh_client()
return self._client
async def get_query_name_client(self) -> AsyncClient:
await self._refresh_client()
return self._client
bilibili_client = BilibiliClient()
class BaseSchedConf(ABC, SchedulerConfig):
schedule_type = "interval"
bilibili_client: BilibiliClient
def __init__(self):
super().__init__()
self.bilibili_client = bilibili_client
async def get_client(self, _: Target) -> AsyncClient:
return await self.bilibili_client.get_client()
async def get_query_name_client(self) -> AsyncClient:
return await self.bilibili_client.get_client()
class BilibiliSchedConf(BaseSchedConf):
class BilibiliSchedConf(SchedulerConfig):
name = "bilibili.com"
schedule_type = "interval"
schedule_setting = {"seconds": 10}
client_man = BilibiliClient
class BililiveSchedConf(BaseSchedConf):
class BililiveSchedConf(SchedulerConfig):
name = "live.bilibili.com"
schedule_type = "interval"
schedule_setting = {"seconds": 3}
client_man = BilibiliClient
class Bilibili(NewMessage):
@@ -198,8 +188,9 @@ class Bilibili(NewMessage):
)
async def get_sub_list(self, target: Target) -> list[DynRawPost]:
client = await self.ctx.get_client()
params = {"host_uid": target, "offset": 0, "need_top": 0}
res = await self.client.get(
res = await client.get(
"https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history",
params=params,
timeout=4.0,
@@ -428,8 +419,9 @@ class Bilibililive(StatusChange):
)
async def batch_get_status(self, targets: list[Target]) -> list[Info]:
client = await self.ctx.get_client()
# https://github.com/SocialSisterYi/bilibili-API-collect/blob/master/docs/live/info.md#批量查询直播间状态
res = await self.client.get(
res = await client.get(
"https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids",
params={"uids[]": targets},
timeout=4.0,
@@ -520,7 +512,8 @@ class BilibiliBangumi(StatusChange):
)
async def get_status(self, target: Target):
res = await self.client.get(
client = await self.ctx.get_client()
res = await client.get(
self._url,
params={"media_id": target},
timeout=4.0,
@@ -542,9 +535,8 @@ class BilibiliBangumi(StatusChange):
return []
async def parse(self, raw_post: RawPost) -> Post:
detail_res = await self.client.get(
f'https://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}'
)
client = await self.ctx.get_client()
detail_res = await client.get(f'https://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}')
detail_dict = detail_res.json()
lastest_episode = None
for episode in detail_dict["result"]["episodes"][::-1]:
+2 -1
View File
@@ -24,7 +24,8 @@ class FF14(NewMessage):
return "最终幻想XIV官方公告"
async def get_sub_list(self, _) -> list[RawPost]:
raw_data = await self.client.get(
client = await self.ctx.get_client()
raw_data = await client.get(
"https://cqnews.web.sdo.com/api/news/newsList?gameCode=ff&CategoryCode=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5"
)
return raw_data.json()["Data"]
+4 -2
View File
@@ -47,7 +47,8 @@ class NcmArtist(NewMessage):
raise cls.ParseTargetException("正确格式:\n1. 歌手数字ID\n2. https://music.163.com/#/artist?id=xxxx")
async def get_sub_list(self, target: Target) -> list[RawPost]:
res = await self.client.get(
client = await self.ctx.get_client()
res = await client.get(
f"https://music.163.com/api/artist/albums/{target}",
headers={"Referer": "https://music.163.com/"},
)
@@ -106,7 +107,8 @@ class NcmRadio(NewMessage):
)
async def get_sub_list(self, target: Target) -> list[RawPost]:
res = await self.client.post(
client = await self.ctx.get_client()
res = await client.post(
"http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"},
data={"radioId": target, "limit": 1000, "offset": 0},
+6 -8
View File
@@ -92,7 +92,6 @@ class Platform(metaclass=PlatformABCMeta, base=True):
platform_name: str
parse_target_promot: str | None = None
registry: list[type["Platform"]]
client: AsyncClient
reverse_category: dict[str, Category]
use_batch: bool = False
# TODO: 限定可使用的theme名称
@@ -121,9 +120,8 @@ class Platform(metaclass=PlatformABCMeta, base=True):
"actually function called"
return await self.parse(raw_post)
def __init__(self, context: ProcessContext, client: AsyncClient):
def __init__(self, context: ProcessContext):
super().__init__()
self.client = client
self.ctx = context
class ParseTargetException(Exception):
@@ -225,8 +223,8 @@ class Platform(metaclass=PlatformABCMeta, base=True):
class MessageProcess(Platform, abstract=True):
"General message process fetch, parse, filter progress"
def __init__(self, ctx: ProcessContext, client: AsyncClient):
super().__init__(ctx, client)
def __init__(self, ctx: ProcessContext):
super().__init__(ctx)
self.parse_cache: dict[Any, Post] = {}
@abstractmethod
@@ -463,11 +461,11 @@ def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
if platform.scheduler != scheduler:
raise RuntimeError(f"Platform scheduler for {platform_name} not fit")
def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient):
Platform.__init__(self, ctx, client)
def __init__(self: "NoTargetGroup", ctx: ProcessContext):
Platform.__init__(self, ctx)
self.platform_obj_list = []
for platform_class in self.platform_list:
self.platform_obj_list.append(platform_class(ctx, client))
self.platform_obj_list.append(platform_class(ctx))
def __str__(self: "NoTargetGroup") -> str:
return "[" + " ".join(x.name for x in self.platform_list) + "]"
+2 -1
View File
@@ -46,7 +46,8 @@ class Rss(NewMessage):
return post.id
async def get_sub_list(self, target: Target) -> list[RawPost]:
res = await self.client.get(target, timeout=10.0)
client = await self.ctx.get_client()
res = await client.get(target, timeout=10.0)
feed = feedparser.parse(res)
entries = feed.entries
for entry in entries:
+4 -2
View File
@@ -78,8 +78,9 @@ class Weibo(NewMessage):
raise cls.ParseTargetException(prompt="正确格式:\n1. 用户数字UID\n2. https://weibo.com/u/xxxx")
async def get_sub_list(self, target: Target) -> list[RawPost]:
client = await self.ctx.get_client()
params = {"containerid": "107603" + target}
res = await self.client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0)
res = await client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0)
res_data = json.loads(res.text)
if not res_data["ok"] and res_data["msg"] != "这里还没有内容":
raise ApiError(res.request.url)
@@ -149,7 +150,8 @@ class Weibo(NewMessage):
async def _get_long_weibo(self, weibo_id: str) -> dict:
try:
weibo_info = await self.client.get(
client = await self.ctx.get_client()
weibo_info = await client.get(
"https://m.weibo.cn/statuses/show",
params={"id": weibo_id},
headers=_HEADER,