From 2fc11a965300673d01cdae2e1fa1c157e1da2484 Mon Sep 17 00:00:00 2001 From: felinae98 <731499577@qq.com> Date: Tue, 28 May 2024 20:59:39 +0800 Subject: [PATCH] :recycle: refactor client of scheduler_config --- nonebot_bison/apis.py | 2 +- nonebot_bison/platform/arknights.py | 22 +++++---- nonebot_bison/platform/bilibili.py | 50 +++++++++------------ nonebot_bison/platform/ff14.py | 3 +- nonebot_bison/platform/ncm.py | 6 ++- nonebot_bison/platform/platform.py | 14 +++--- nonebot_bison/platform/rss.py | 3 +- nonebot_bison/platform/weibo.py | 6 ++- nonebot_bison/scheduler/scheduler.py | 10 +++-- nonebot_bison/theme/themes/basic/build.py | 3 +- nonebot_bison/theme/themes/brief/build.py | 3 +- nonebot_bison/theme/themes/ht2i/build.py | 3 +- nonebot_bison/utils/__init__.py | 4 +- nonebot_bison/utils/context.py | 27 +++++++++-- nonebot_bison/utils/scheduler_config.py | 30 ++++++++++--- tests/platforms/test_arknights.py | 18 ++++---- tests/platforms/test_bilibili.py | 5 ++- tests/platforms/test_bilibili_bangumi.py | 5 ++- tests/platforms/test_bilibili_live.py | 26 ++--------- tests/platforms/test_ff14.py | 5 ++- tests/platforms/test_ncm_artist.py | 5 ++- tests/platforms/test_ncm_radio.py | 5 ++- tests/platforms/test_platform.py | 38 +++++++++------- tests/platforms/test_platform_tag_filter.py | 7 +-- tests/platforms/test_rss.py | 5 ++- tests/platforms/test_weibo.py | 3 +- tests/post/test_generate.py | 7 +-- tests/scheduler/test_scheduler.py | 4 +- tests/test_context.py | 5 ++- tests/theme/test_themes.py | 4 +- 30 files changed, 185 insertions(+), 143 deletions(-) diff --git a/nonebot_bison/apis.py b/nonebot_bison/apis.py index 6d5130e..d294a5a 100644 --- a/nonebot_bison/apis.py +++ b/nonebot_bison/apis.py @@ -7,6 +7,6 @@ async def check_sub_target(platform_name: str, target: Target): platform = platform_manager[platform_name] scheduler_conf_class = platform.scheduler scheduler = scheduler_dict[scheduler_conf_class] - client = await scheduler.scheduler_config_obj.get_query_name_client() + client = await scheduler.client_mgr.get_query_name_client() return await platform_manager[platform_name].get_target_name(client, target) diff --git a/nonebot_bison/platform/arknights.py b/nonebot_bison/platform/arknights.py index e931d69..e5b4168 100644 --- a/nonebot_bison/platform/arknights.py +++ b/nonebot_bison/platform/arknights.py @@ -74,7 +74,8 @@ class Arknights(NewMessage): return "明日方舟游戏信息" async def get_sub_list(self, _) -> list[BulletinListItem]: - raw_data = await self.client.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS") + client = await self.ctx.get_client() + raw_data = await client.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS") return type_validate_python(ArkBulletinListResponse, raw_data.json()).data.list def get_id(self, post: BulletinListItem) -> Any: @@ -91,9 +92,8 @@ class Arknights(NewMessage): return Category(1) async def parse(self, raw_post: BulletinListItem) -> Post: - raw_data = await self.client.get( - f"https://ak-webview.hypergryph.com/api/game/bulletin/{self.get_id(post=raw_post)}" - ) + client = await self.ctx.get_client() + raw_data = await client.get(f"https://ak-webview.hypergryph.com/api/game/bulletin/{self.get_id(post=raw_post)}") data = type_validate_python(ArkBulletinResponse, raw_data.json()).data def title_escape(text: str) -> str: @@ -136,8 +136,9 @@ class AkVersion(StatusChange): return "明日方舟游戏信息" async def get_status(self, _): - res_ver = await self.client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version") - res_preanounce = await self.client.get( + client = await self.ctx.get_client() + res_ver = await client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version") + res_preanounce = await client.get( "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json" ) res = res_ver.json() @@ -179,7 +180,8 @@ class MonsterSiren(NewMessage): return "明日方舟游戏信息" async def get_sub_list(self, _) -> list[RawPost]: - raw_data = await self.client.get("https://monster-siren.hypergryph.com/api/news") + client = await self.ctx.get_client() + raw_data = await client.get("https://monster-siren.hypergryph.com/api/news") return raw_data.json()["data"]["list"] def get_id(self, post: RawPost) -> Any: @@ -192,8 +194,9 @@ class MonsterSiren(NewMessage): return Category(3) async def parse(self, raw_post: RawPost) -> Post: + client = await self.ctx.get_client() url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}' - res = await self.client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}') + res = await client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}') raw_data = res.json() content = raw_data["data"]["content"] content = content.replace("

", "

\n") @@ -226,7 +229,8 @@ class TerraHistoricusComic(NewMessage): return "明日方舟游戏信息" async def get_sub_list(self, _) -> list[RawPost]: - raw_data = await self.client.get("https://terra-historicus.hypergryph.com/api/recentUpdate") + client = await self.ctx.get_client() + raw_data = await client.get("https://terra-historicus.hypergryph.com/api/recentUpdate") return raw_data.json()["data"] def get_id(self, post: RawPost) -> Any: diff --git a/nonebot_bison/platform/bilibili.py b/nonebot_bison/platform/bilibili.py index acd67d6..053a591 100644 --- a/nonebot_bison/platform/bilibili.py +++ b/nonebot_bison/platform/bilibili.py @@ -1,6 +1,5 @@ import re import json -from abc import ABC from copy import deepcopy from enum import Enum, unique from typing_extensions import Self @@ -13,6 +12,7 @@ from pydantic import Field, BaseModel from nonebot.compat import PYDANTIC_V2, ConfigDict, type_validate_json, type_validate_python from nonebot_bison.compat import model_rebuild +from nonebot_bison.utils.scheduler_config import ClientManager from ..post import Post from ..types import Tag, Target, RawPost, ApiError, Category @@ -104,7 +104,7 @@ model_rebuild_recurse(UserAPI) model_rebuild_recurse(PostAPI) -class BilibiliClient: +class BilibiliClient(ClientManager): _client: AsyncClient _refresh_time: datetime cookie_expire_time = timedelta(hours=5) @@ -124,37 +124,27 @@ class BilibiliClient: if datetime.now() - self._refresh_time > self.cookie_expire_time: await self._init_session() - async def get_client(self) -> AsyncClient: + async def get_client(self, target: Target | None) -> AsyncClient: + await self._refresh_client() + return self._client + + async def get_query_name_client(self) -> AsyncClient: await self._refresh_client() return self._client -bilibili_client = BilibiliClient() - - -class BaseSchedConf(ABC, SchedulerConfig): - schedule_type = "interval" - bilibili_client: BilibiliClient - - def __init__(self): - super().__init__() - self.bilibili_client = bilibili_client - - async def get_client(self, _: Target) -> AsyncClient: - return await self.bilibili_client.get_client() - - async def get_query_name_client(self) -> AsyncClient: - return await self.bilibili_client.get_client() - - -class BilibiliSchedConf(BaseSchedConf): +class BilibiliSchedConf(SchedulerConfig): name = "bilibili.com" + schedule_type = "interval" schedule_setting = {"seconds": 10} + client_man = BilibiliClient -class BililiveSchedConf(BaseSchedConf): +class BililiveSchedConf(SchedulerConfig): name = "live.bilibili.com" + schedule_type = "interval" schedule_setting = {"seconds": 3} + client_man = BilibiliClient class Bilibili(NewMessage): @@ -198,8 +188,9 @@ class Bilibili(NewMessage): ) async def get_sub_list(self, target: Target) -> list[DynRawPost]: + client = await self.ctx.get_client() params = {"host_uid": target, "offset": 0, "need_top": 0} - res = await self.client.get( + res = await client.get( "https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history", params=params, timeout=4.0, @@ -428,8 +419,9 @@ class Bilibililive(StatusChange): ) async def batch_get_status(self, targets: list[Target]) -> list[Info]: + client = await self.ctx.get_client() # https://github.com/SocialSisterYi/bilibili-API-collect/blob/master/docs/live/info.md#批量查询直播间状态 - res = await self.client.get( + res = await client.get( "https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids", params={"uids[]": targets}, timeout=4.0, @@ -520,7 +512,8 @@ class BilibiliBangumi(StatusChange): ) async def get_status(self, target: Target): - res = await self.client.get( + client = await self.ctx.get_client() + res = await client.get( self._url, params={"media_id": target}, timeout=4.0, @@ -542,9 +535,8 @@ class BilibiliBangumi(StatusChange): return [] async def parse(self, raw_post: RawPost) -> Post: - detail_res = await self.client.get( - f'https://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}' - ) + client = await self.ctx.get_client() + detail_res = await client.get(f'https://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}') detail_dict = detail_res.json() lastest_episode = None for episode in detail_dict["result"]["episodes"][::-1]: diff --git a/nonebot_bison/platform/ff14.py b/nonebot_bison/platform/ff14.py index e050aae..667d456 100644 --- a/nonebot_bison/platform/ff14.py +++ b/nonebot_bison/platform/ff14.py @@ -24,7 +24,8 @@ class FF14(NewMessage): return "最终幻想XIV官方公告" async def get_sub_list(self, _) -> list[RawPost]: - raw_data = await self.client.get( + client = await self.ctx.get_client() + raw_data = await client.get( "https://cqnews.web.sdo.com/api/news/newsList?gameCode=ff&CategoryCode=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5" ) return raw_data.json()["Data"] diff --git a/nonebot_bison/platform/ncm.py b/nonebot_bison/platform/ncm.py index 538518a..d20d13d 100644 --- a/nonebot_bison/platform/ncm.py +++ b/nonebot_bison/platform/ncm.py @@ -47,7 +47,8 @@ class NcmArtist(NewMessage): raise cls.ParseTargetException("正确格式:\n1. 歌手数字ID\n2. https://music.163.com/#/artist?id=xxxx") async def get_sub_list(self, target: Target) -> list[RawPost]: - res = await self.client.get( + client = await self.ctx.get_client() + res = await client.get( f"https://music.163.com/api/artist/albums/{target}", headers={"Referer": "https://music.163.com/"}, ) @@ -106,7 +107,8 @@ class NcmRadio(NewMessage): ) async def get_sub_list(self, target: Target) -> list[RawPost]: - res = await self.client.post( + client = await self.ctx.get_client() + res = await client.post( "http://music.163.com/api/dj/program/byradio", headers={"Referer": "https://music.163.com/"}, data={"radioId": target, "limit": 1000, "offset": 0}, diff --git a/nonebot_bison/platform/platform.py b/nonebot_bison/platform/platform.py index 2bf02f3..35975a9 100644 --- a/nonebot_bison/platform/platform.py +++ b/nonebot_bison/platform/platform.py @@ -92,7 +92,6 @@ class Platform(metaclass=PlatformABCMeta, base=True): platform_name: str parse_target_promot: str | None = None registry: list[type["Platform"]] - client: AsyncClient reverse_category: dict[str, Category] use_batch: bool = False # TODO: 限定可使用的theme名称 @@ -121,9 +120,8 @@ class Platform(metaclass=PlatformABCMeta, base=True): "actually function called" return await self.parse(raw_post) - def __init__(self, context: ProcessContext, client: AsyncClient): + def __init__(self, context: ProcessContext): super().__init__() - self.client = client self.ctx = context class ParseTargetException(Exception): @@ -225,8 +223,8 @@ class Platform(metaclass=PlatformABCMeta, base=True): class MessageProcess(Platform, abstract=True): "General message process fetch, parse, filter progress" - def __init__(self, ctx: ProcessContext, client: AsyncClient): - super().__init__(ctx, client) + def __init__(self, ctx: ProcessContext): + super().__init__(ctx) self.parse_cache: dict[Any, Post] = {} @abstractmethod @@ -463,11 +461,11 @@ def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]: if platform.scheduler != scheduler: raise RuntimeError(f"Platform scheduler for {platform_name} not fit") - def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient): - Platform.__init__(self, ctx, client) + def __init__(self: "NoTargetGroup", ctx: ProcessContext): + Platform.__init__(self, ctx) self.platform_obj_list = [] for platform_class in self.platform_list: - self.platform_obj_list.append(platform_class(ctx, client)) + self.platform_obj_list.append(platform_class(ctx)) def __str__(self: "NoTargetGroup") -> str: return "[" + " ".join(x.name for x in self.platform_list) + "]" diff --git a/nonebot_bison/platform/rss.py b/nonebot_bison/platform/rss.py index a7af592..a66dad9 100644 --- a/nonebot_bison/platform/rss.py +++ b/nonebot_bison/platform/rss.py @@ -46,7 +46,8 @@ class Rss(NewMessage): return post.id async def get_sub_list(self, target: Target) -> list[RawPost]: - res = await self.client.get(target, timeout=10.0) + client = await self.ctx.get_client() + res = await client.get(target, timeout=10.0) feed = feedparser.parse(res) entries = feed.entries for entry in entries: diff --git a/nonebot_bison/platform/weibo.py b/nonebot_bison/platform/weibo.py index 1b4d181..9a02b8d 100644 --- a/nonebot_bison/platform/weibo.py +++ b/nonebot_bison/platform/weibo.py @@ -78,8 +78,9 @@ class Weibo(NewMessage): raise cls.ParseTargetException(prompt="正确格式:\n1. 用户数字UID\n2. https://weibo.com/u/xxxx") async def get_sub_list(self, target: Target) -> list[RawPost]: + client = await self.ctx.get_client() params = {"containerid": "107603" + target} - res = await self.client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0) + res = await client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0) res_data = json.loads(res.text) if not res_data["ok"] and res_data["msg"] != "这里还没有内容": raise ApiError(res.request.url) @@ -149,7 +150,8 @@ class Weibo(NewMessage): async def _get_long_weibo(self, weibo_id: str) -> dict: try: - weibo_info = await self.client.get( + client = await self.ctx.get_client() + weibo_info = await client.get( "https://m.weibo.cn/statuses/show", params={"id": weibo_id}, headers=_HEADER, diff --git a/nonebot_bison/scheduler/scheduler.py b/nonebot_bison/scheduler/scheduler.py index b1fc530..76962f1 100644 --- a/nonebot_bison/scheduler/scheduler.py +++ b/nonebot_bison/scheduler/scheduler.py @@ -5,6 +5,8 @@ from nonebot.log import logger from nonebot_plugin_apscheduler import scheduler from nonebot_plugin_saa.utils.exceptions import NoBotFound +from nonebot_bison.utils.scheduler_config import ClientManager + from ..config import config from ..send import send_msgs from ..types import Target, SubUnit @@ -24,6 +26,7 @@ class Scheduler: schedulable_list: list[Schedulable] # for load weigth from db batch_api_target_cache: dict[str, dict[Target, list[Target]]] # platform_name -> (target -> [target]) batch_platform_name_targets_cache: dict[str, list[Target]] + client_mgr: ClientManager def __init__( self, @@ -36,6 +39,7 @@ class Scheduler: logger.error(f"scheduler config [{self.name}] not found, exiting") raise RuntimeError(f"{self.name} not found") self.scheduler_config = scheduler_config + self.client_mgr = scheduler_config.client_mgr() self.scheduler_config_obj = self.scheduler_config() self.schedulable_list = [] @@ -83,16 +87,14 @@ class Scheduler: return cur_max_schedulable async def exec_fetch(self): - context = ProcessContext() if not (schedulable := await self.get_next_schedulable()): return logger.trace(f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}") - client = await self.scheduler_config_obj.get_client(schedulable.target) - context.register_to_client(client) + context = ProcessContext(self.client_mgr) try: - platform_obj = platform_manager[schedulable.platform_name](context, client) + platform_obj = platform_manager[schedulable.platform_name](context) if schedulable.use_batch: batch_targets = self.batch_api_target_cache[schedulable.platform_name][schedulable.target] sub_units = [] diff --git a/nonebot_bison/theme/themes/basic/build.py b/nonebot_bison/theme/themes/basic/build.py index 93ed440..c801258 100644 --- a/nonebot_bison/theme/themes/basic/build.py +++ b/nonebot_bison/theme/themes/basic/build.py @@ -42,11 +42,12 @@ class BasicTheme(Theme): if urls: text += "\n".join(urls) + client = await post.platform.ctx.get_client_for_static() msgs: list[MessageSegmentFactory] = [Text(text)] if post.images: pics = post.images if is_pics_mergable(pics): - pics = await pic_merge(list(pics), post.platform.client) + pics = await pic_merge(list(pics), client) msgs.extend(map(Image, pics)) return msgs diff --git a/nonebot_bison/theme/themes/brief/build.py b/nonebot_bison/theme/themes/brief/build.py index 2e9bf63..612d41a 100644 --- a/nonebot_bison/theme/themes/brief/build.py +++ b/nonebot_bison/theme/themes/brief/build.py @@ -29,11 +29,12 @@ class BriefTheme(Theme): if urls: text += "\n".join(urls) + client = await post.platform.ctx.get_client_for_static() msgs: list[MessageSegmentFactory] = [Text(text)] if post.images: pics = post.images if is_pics_mergable(pics): - pics = await pic_merge(list(pics), post.platform.client) + pics = await pic_merge(list(pics), client) msgs.append(Image(pics[0])) return msgs diff --git a/nonebot_bison/theme/themes/ht2i/build.py b/nonebot_bison/theme/themes/ht2i/build.py index d03ce80..88c6d4a 100644 --- a/nonebot_bison/theme/themes/ht2i/build.py +++ b/nonebot_bison/theme/themes/ht2i/build.py @@ -54,9 +54,10 @@ class Ht2iTheme(Theme): msgs.append(Text("\n".join(urls))) if post.images: + client = await post.platform.ctx.get_client_for_static() pics = post.images if is_pics_mergable(pics): - pics = await pic_merge(list(pics), post.platform.client) + pics = await pic_merge(list(pics), client) msgs.extend(map(Image, pics)) return msgs diff --git a/nonebot_bison/utils/__init__.py b/nonebot_bison/utils/__init__.py index 29da939..64f6fc4 100644 --- a/nonebot_bison/utils/__init__.py +++ b/nonebot_bison/utils/__init__.py @@ -11,14 +11,16 @@ from nonebot_plugin_saa import Text, Image, MessageSegmentFactory from .http import http_client from .context import ProcessContext from ..plugin_config import plugin_config -from .scheduler_config import SchedulerConfig, scheduler from .image import pic_merge, text_to_image, is_pics_mergable, pic_url_to_image +from .scheduler_config import ClientManager, SchedulerConfig, DefaultClientManager, scheduler __all__ = [ "http_client", "Singleton", "parse_text", "ProcessContext", + "ClientManager", + "DefaultClientManager", "html_to_text", "SchedulerConfig", "scheduler", diff --git a/nonebot_bison/utils/context.py b/nonebot_bison/utils/context.py index 83403a0..7981370 100644 --- a/nonebot_bison/utils/context.py +++ b/nonebot_bison/utils/context.py @@ -2,19 +2,25 @@ from base64 import b64encode from httpx import Response, AsyncClient +from nonebot_bison.types import Target + +from .scheduler_config import ClientManager + class ProcessContext: reqs: list[Response] + _client_mgr: ClientManager - def __init__(self) -> None: + def __init__(self, client_mgr: ClientManager) -> None: self.reqs = [] + self._client_mgr = client_mgr - def log_response(self, resp: Response): + def _log_response(self, resp: Response): self.reqs.append(resp) - def register_to_client(self, client: AsyncClient): + def _register_to_client(self, client: AsyncClient): async def _log_to_ctx(r: Response): - self.log_response(r) + self._log_response(r) hooks = { "response": [_log_to_ctx], @@ -41,3 +47,16 @@ class ProcessContext: ) res.append(log_content) return res + + async def get_client(self, target: Target | None = None) -> AsyncClient: + client = await self._client_mgr.get_client(target) + self._register_to_client(client) + return client + + async def get_client_for_static(self) -> AsyncClient: + client = await self._client_mgr.get_client_for_static() + self._register_to_client(client) + return client + + async def refresh_client(self): + await self._client_mgr.refresh_client() diff --git a/nonebot_bison/utils/scheduler_config.py b/nonebot_bison/utils/scheduler_config.py index 293eae6..cc9fdfd 100644 --- a/nonebot_bison/utils/scheduler_config.py +++ b/nonebot_bison/utils/scheduler_config.py @@ -1,3 +1,4 @@ +from abc import ABC from typing import Literal from httpx import AsyncClient @@ -6,10 +7,32 @@ from ..types import Target from .http import http_client +class ClientManager(ABC): + async def get_client(self, target: Target | None) -> AsyncClient: ... + + async def get_client_for_static(self) -> AsyncClient: ... + + async def get_query_name_client(self) -> AsyncClient: ... + + async def refresh_client(self): ... + + +class DefaultClientManager(ClientManager): + async def get_client(self, target: Target | None) -> AsyncClient: + return http_client() + + async def get_client_for_static(self) -> AsyncClient: + return http_client() + + async def get_query_name_client(self) -> AsyncClient: + return http_client() + + class SchedulerConfig: schedule_type: Literal["date", "interval", "cron"] schedule_setting: dict name: str + client_mgr: type[ClientManager] = DefaultClientManager require_browser: bool = False def __str__(self): @@ -18,12 +41,6 @@ class SchedulerConfig: def __init__(self): self.default_http_client = http_client() - async def get_client(self, target: Target) -> AsyncClient: - return self.default_http_client - - async def get_query_name_client(self) -> AsyncClient: - return self.default_http_client - def scheduler(schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict) -> type[SchedulerConfig]: return type( @@ -32,5 +49,6 @@ def scheduler(schedule_type: Literal["date", "interval", "cron"], schedule_setti { "schedule_type": schedule_type, "schedule_setting": schedule_setting, + "client_mgr": ClientManager, }, ) diff --git a/tests/platforms/test_arknights.py b/tests/platforms/test_arknights.py index 12146f2..5d44819 100644 --- a/tests/platforms/test_arknights.py +++ b/tests/platforms/test_arknights.py @@ -2,8 +2,8 @@ from time import time import respx import pytest +from httpx import Response from nonebug.app import App -from httpx import Response, AsyncClient from nonebot.compat import model_dump, type_validate_python from .utils import get_file, get_json @@ -13,8 +13,9 @@ from .utils import get_file, get_json def arknights(app: App): from nonebot_bison.utils import ProcessContext from nonebot_bison.platform import platform_manager + from nonebot_bison.utils.scheduler_config import DefaultClientManager - return platform_manager["arknights"](ProcessContext(), AsyncClient()) + return platform_manager["arknights"](ProcessContext(DefaultClientManager())) @pytest.fixture(scope="module") @@ -44,9 +45,8 @@ def monster_siren_list_1(): @respx.mock async def test_url_parse(app: App): - from httpx import AsyncClient - from nonebot_bison.utils import ProcessContext + from nonebot_bison.utils.scheduler_config import DefaultClientManager from nonebot_bison.platform.arknights import Arknights, BulletinData, BulletinListItem, ArkBulletinResponse cid_router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletin/1") @@ -93,7 +93,7 @@ async def test_url_parse(app: App): b4 = make_bulletin_obj("http://www.baidu.com") assert b4.jump_link == "http://www.baidu.com" - ark = Arknights(ProcessContext(), AsyncClient()) + ark = Arknights(ProcessContext(DefaultClientManager())) cid_router.mock(return_value=make_response(b1)) p1 = await ark.parse(make_bulletin_list_item_obj()) @@ -115,9 +115,10 @@ async def test_url_parse(app: App): @pytest.mark.asyncio() async def test_get_date_in_bulletin(app: App): from nonebot_bison.utils import ProcessContext + from nonebot_bison.utils.scheduler_config import DefaultClientManager from nonebot_bison.platform.arknights import Arknights, BulletinListItem - arknights = Arknights(ProcessContext(), AsyncClient()) + arknights = Arknights(ProcessContext(DefaultClientManager())) assert ( arknights.get_date( BulletinListItem( @@ -136,13 +137,14 @@ async def test_get_date_in_bulletin(app: App): @pytest.mark.asyncio() @respx.mock async def test_parse_with_breakline(app: App): - from nonebot_bison.utils import ProcessContext, http_client + from nonebot_bison.utils import ProcessContext + from nonebot_bison.utils.scheduler_config import DefaultClientManager from nonebot_bison.platform.arknights import Arknights, BulletinListItem detail = get_json("arknights-detail-805") detail["data"]["header"] = "" - arknights = Arknights(ProcessContext(), http_client()) + arknights = Arknights(ProcessContext(DefaultClientManager())) router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletin/1") router.mock(return_value=Response(200, json=detail)) diff --git a/tests/platforms/test_bilibili.py b/tests/platforms/test_bilibili.py index c10327e..577377c 100644 --- a/tests/platforms/test_bilibili.py +++ b/tests/platforms/test_bilibili.py @@ -3,8 +3,8 @@ from datetime import datetime import respx import pytest +from httpx import Response from nonebug.app import App -from httpx import Response, AsyncClient from nonebot.compat import model_dump, type_validate_python from .utils import get_json @@ -25,8 +25,9 @@ if typing.TYPE_CHECKING: def bilibili(app: App) -> "Bilibili": from nonebot_bison.utils import ProcessContext from nonebot_bison.platform import platform_manager + from nonebot_bison.utils.scheduler_config import DefaultClientManager - return platform_manager["bilibili"](ProcessContext(), AsyncClient()) # type: ignore + return platform_manager["bilibili"](ProcessContext(DefaultClientManager())) # type: ignore @pytest.fixture() diff --git a/tests/platforms/test_bilibili_bangumi.py b/tests/platforms/test_bilibili_bangumi.py index 0df1701..262926b 100644 --- a/tests/platforms/test_bilibili_bangumi.py +++ b/tests/platforms/test_bilibili_bangumi.py @@ -2,8 +2,8 @@ import typing import respx import pytest +from httpx import Response from nonebug.app import App -from httpx import Response, AsyncClient from .utils import get_json @@ -15,8 +15,9 @@ if typing.TYPE_CHECKING: def bili_bangumi(app: App): from nonebot_bison.utils import ProcessContext from nonebot_bison.platform import platform_manager + from nonebot_bison.utils.scheduler_config import DefaultClientManager - return platform_manager["bilibili-bangumi"](ProcessContext(), AsyncClient()) + return platform_manager["bilibili-bangumi"](ProcessContext(DefaultClientManager())) async def test_parse_target(bili_bangumi: "BilibiliBangumi"): diff --git a/tests/platforms/test_bilibili_live.py b/tests/platforms/test_bilibili_live.py index f6208cc..604830b 100644 --- a/tests/platforms/test_bilibili_live.py +++ b/tests/platforms/test_bilibili_live.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING import respx import pytest +from httpx import Response from nonebug.app import App -from httpx import Response, AsyncClient from .utils import get_json @@ -16,8 +16,9 @@ if TYPE_CHECKING: def bili_live(app: App): from nonebot_bison.utils import ProcessContext from nonebot_bison.platform import platform_manager + from nonebot_bison.platform.bilibili import BilibiliClient - return platform_manager["bilibili-live"](ProcessContext(), AsyncClient()) + return platform_manager["bilibili-live"](ProcessContext(BilibiliClient())) @pytest.fixture() @@ -30,27 +31,6 @@ def dummy_only_open_user_subinfo(app: App): return UserSubInfo(user=user, categories=[1], tags=[]) -@pytest.mark.asyncio -async def test_http_client_equal(app: App): - from nonebot_bison.types import Target - from nonebot_bison.utils import ProcessContext - from nonebot_bison.platform import platform_manager - - empty_target = Target("0") - - bilibili = platform_manager["bilibili"](ProcessContext(), AsyncClient()) - bilibili_live = platform_manager["bilibili-live"](ProcessContext(), AsyncClient()) - - bilibili_scheduler = bilibili.scheduler() - bilibili_live_scheduler = bilibili_live.scheduler() - - assert await bilibili_scheduler.get_client(empty_target) == await bilibili_live_scheduler.get_client(empty_target) - assert await bilibili_live_scheduler.get_client(empty_target) != bilibili_live_scheduler.default_http_client - - assert await bilibili_scheduler.get_query_name_client() == await bilibili_live_scheduler.get_query_name_client() - assert await bilibili_scheduler.get_query_name_client() != bilibili_live_scheduler.default_http_client - - @pytest.mark.asyncio @respx.mock async def test_fetch_bililive_no_room(bili_live, dummy_only_open_user_subinfo): diff --git a/tests/platforms/test_ff14.py b/tests/platforms/test_ff14.py index 485226c..eb22d7e 100644 --- a/tests/platforms/test_ff14.py +++ b/tests/platforms/test_ff14.py @@ -1,7 +1,7 @@ import respx import pytest +from httpx import Response from nonebug.app import App -from httpx import Response, AsyncClient from .utils import get_json @@ -10,8 +10,9 @@ from .utils import get_json def ff14(app: App): from nonebot_bison.utils import ProcessContext from nonebot_bison.platform import platform_manager + from nonebot_bison.utils.scheduler_config import DefaultClientManager - return platform_manager["ff14"](ProcessContext(), AsyncClient()) + return platform_manager["ff14"](ProcessContext(DefaultClientManager())) @pytest.fixture(scope="module") diff --git a/tests/platforms/test_ncm_artist.py b/tests/platforms/test_ncm_artist.py index 0550d2e..f6adf05 100644 --- a/tests/platforms/test_ncm_artist.py +++ b/tests/platforms/test_ncm_artist.py @@ -3,8 +3,8 @@ import typing import respx import pytest +from httpx import Response from nonebug.app import App -from httpx import Response, AsyncClient from .utils import get_json @@ -16,8 +16,9 @@ if typing.TYPE_CHECKING: def ncm_artist(app: App): from nonebot_bison.utils import ProcessContext from nonebot_bison.platform import platform_manager + from nonebot_bison.utils.scheduler_config import DefaultClientManager - return platform_manager["ncm-artist"](ProcessContext(), AsyncClient()) + return platform_manager["ncm-artist"](ProcessContext(DefaultClientManager())) @pytest.fixture(scope="module") diff --git a/tests/platforms/test_ncm_radio.py b/tests/platforms/test_ncm_radio.py index 461b57b..98c213a 100644 --- a/tests/platforms/test_ncm_radio.py +++ b/tests/platforms/test_ncm_radio.py @@ -3,8 +3,8 @@ import typing import respx import pytest +from httpx import Response from nonebug.app import App -from httpx import Response, AsyncClient from .utils import get_json @@ -16,8 +16,9 @@ if typing.TYPE_CHECKING: def ncm_radio(app: App): from nonebot_bison.utils import ProcessContext from nonebot_bison.platform import platform_manager + from nonebot_bison.utils.scheduler_config import DefaultClientManager - return platform_manager["ncm-radio"](ProcessContext(), AsyncClient()) + return platform_manager["ncm-radio"](ProcessContext(DefaultClientManager())) @pytest.fixture(scope="module") diff --git a/tests/platforms/test_platform.py b/tests/platforms/test_platform.py index 7a5714c..2f02c66 100644 --- a/tests/platforms/test_platform.py +++ b/tests/platforms/test_platform.py @@ -3,7 +3,6 @@ from typing import Any import pytest from nonebug.app import App -from httpx import AsyncClient now = time() passed = now - 3 * 60 * 60 @@ -326,12 +325,13 @@ def mock_status_change(app: App): async def test_new_message_target_without_cats_tags(mock_platform_without_cats_tags, user_info_factory): from nonebot_bison.utils import ProcessContext from nonebot_bison.types import Target, SubUnit + from nonebot_bison.utils.scheduler_config import DefaultClientManager - res1 = await mock_platform_without_cats_tags(ProcessContext(), AsyncClient()).fetch_new_post( + res1 = await mock_platform_without_cats_tags(ProcessContext(DefaultClientManager())).fetch_new_post( SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) ) assert len(res1) == 0 - res2 = await mock_platform_without_cats_tags(ProcessContext(), AsyncClient()).fetch_new_post( + res2 = await mock_platform_without_cats_tags(ProcessContext(DefaultClientManager())).fetch_new_post( SubUnit(Target("dummy"), [user_info_factory([], [])]), ) assert len(res2) == 1 @@ -347,12 +347,13 @@ async def test_new_message_target_without_cats_tags(mock_platform_without_cats_t async def test_new_message_target(mock_platform, user_info_factory): from nonebot_bison.utils import ProcessContext from nonebot_bison.types import Target, SubUnit + from nonebot_bison.utils.scheduler_config import DefaultClientManager - res1 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post( + res1 = await mock_platform(ProcessContext(DefaultClientManager())).fetch_new_post( SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) ) assert len(res1) == 0 - res2 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post( + res2 = await mock_platform(ProcessContext(DefaultClientManager())).fetch_new_post( SubUnit( Target("dummy"), [ @@ -382,12 +383,13 @@ async def test_new_message_target(mock_platform, user_info_factory): async def test_new_message_no_target(mock_platform_no_target, user_info_factory): from nonebot_bison.utils import ProcessContext from nonebot_bison.types import Target, SubUnit + from nonebot_bison.utils.scheduler_config import DefaultClientManager - res1 = await mock_platform_no_target(ProcessContext(), AsyncClient()).fetch_new_post( + res1 = await mock_platform_no_target(ProcessContext(DefaultClientManager())).fetch_new_post( SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) ) assert len(res1) == 0 - res2 = await mock_platform_no_target(ProcessContext(), AsyncClient()).fetch_new_post( + res2 = await mock_platform_no_target(ProcessContext(DefaultClientManager())).fetch_new_post( SubUnit( Target("dummy"), [ @@ -411,7 +413,7 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory) assert "p3" in id_set_1 assert "p2" in id_set_2 assert "p2" in id_set_3 - res3 = await mock_platform_no_target(ProcessContext(), AsyncClient()).fetch_new_post( + res3 = await mock_platform_no_target(ProcessContext(DefaultClientManager())).fetch_new_post( SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) ) assert len(res3) == 0 @@ -421,19 +423,20 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory) async def test_status_change(mock_status_change, user_info_factory): from nonebot_bison.utils import ProcessContext from nonebot_bison.types import Target, SubUnit + from nonebot_bison.utils.scheduler_config import DefaultClientManager - res1 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( + res1 = await mock_status_change(ProcessContext(DefaultClientManager())).fetch_new_post( SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) ) assert len(res1) == 0 - res2 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( + res2 = await mock_status_change(ProcessContext(DefaultClientManager())).fetch_new_post( SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) ) assert len(res2) == 1 posts = res2[0][1] assert len(posts) == 1 assert posts[0].content == "on" - res3 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( + res3 = await mock_status_change(ProcessContext(DefaultClientManager())).fetch_new_post( SubUnit( Target("dummy"), [ @@ -446,7 +449,7 @@ async def test_status_change(mock_status_change, user_info_factory): assert len(res3[0][1]) == 1 assert res3[0][1][0].content == "off" assert len(res3[1][1]) == 0 - res4 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( + res4 = await mock_status_change(ProcessContext(DefaultClientManager())).fetch_new_post( SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) ) assert len(res4) == 0 @@ -459,14 +462,15 @@ async def test_group( mock_platform_no_target_2, user_info_factory, ): + from nonebot_bison.utils import ProcessContext from nonebot_bison.types import Target, SubUnit - from nonebot_bison.utils import ProcessContext, http_client from nonebot_bison.platform.platform import make_no_target_group + from nonebot_bison.utils.scheduler_config import DefaultClientManager dummy = Target("dummy") group_platform_class = make_no_target_group([mock_platform_no_target, mock_platform_no_target_2]) - group_platform = group_platform_class(ProcessContext(), http_client()) + group_platform = group_platform_class(ProcessContext(DefaultClientManager())) res1 = await group_platform.fetch_new_post(SubUnit(dummy, [user_info_factory([1, 4], [])])) assert len(res1) == 0 res2 = await group_platform.fetch_new_post(SubUnit(dummy, [user_info_factory([1, 4], [])])) @@ -487,6 +491,7 @@ async def test_batch_fetch_new_message(app: App): from nonebot_bison.platform.platform import NewMessage from nonebot_bison.utils.context import ProcessContext from nonebot_bison.types import Target, RawPost, SubUnit, UserSubInfo + from nonebot_bison.utils.scheduler_config import DefaultClientManager class BatchNewMessage(NewMessage): platform_name = "mock_platform" @@ -538,7 +543,7 @@ async def test_batch_fetch_new_message(app: App): user1 = UserSubInfo(TargetQQGroup(group_id=123), [1, 2, 3], []) user2 = UserSubInfo(TargetQQGroup(group_id=234), [1, 2, 3], []) - platform_obj = BatchNewMessage(ProcessContext(), None) # type:ignore + platform_obj = BatchNewMessage(ProcessContext(DefaultClientManager())) # type:ignore res1 = await platform_obj.batch_fetch_new_post( [ @@ -572,6 +577,7 @@ async def test_batch_fetch_compare_status(app: App): from nonebot_bison.post import Post from nonebot_bison.utils.context import ProcessContext from nonebot_bison.platform.platform import StatusChange + from nonebot_bison.utils.scheduler_config import DefaultClientManager from nonebot_bison.types import Target, RawPost, SubUnit, Category, UserSubInfo class BatchStatusChange(StatusChange): @@ -612,7 +618,7 @@ async def test_batch_fetch_compare_status(app: App): def get_category(self, raw_post): return raw_post["cat"] - batch_status_change = BatchStatusChange(ProcessContext(), None) # type: ignore + batch_status_change = BatchStatusChange(ProcessContext(DefaultClientManager())) user1 = UserSubInfo(TargetQQGroup(group_id=123), [1, 2, 3], []) user2 = UserSubInfo(TargetQQGroup(group_id=234), [1, 2, 3], []) diff --git a/tests/platforms/test_platform_tag_filter.py b/tests/platforms/test_platform_tag_filter.py index 0b99a3a..dbb4b5d 100644 --- a/tests/platforms/test_platform_tag_filter.py +++ b/tests/platforms/test_platform_tag_filter.py @@ -1,6 +1,5 @@ import pytest from nonebug.app import App -from httpx import AsyncClient from .utils import get_json @@ -15,8 +14,9 @@ def test_cases(): async def test_filter_user_custom_tag(app: App, test_cases): from nonebot_bison.utils import ProcessContext from nonebot_bison.platform import platform_manager + from nonebot_bison.utils.scheduler_config import DefaultClientManager - bilibili = platform_manager["bilibili"](ProcessContext(), AsyncClient()) + bilibili = platform_manager["bilibili"](ProcessContext(DefaultClientManager())) for case in test_cases: res = bilibili.is_banned_post(**case["case"]) assert res == case["result"] @@ -27,8 +27,9 @@ async def test_filter_user_custom_tag(app: App, test_cases): async def test_tag_separator(app: App): from nonebot_bison.utils import ProcessContext from nonebot_bison.platform import platform_manager + from nonebot_bison.utils.scheduler_config import DefaultClientManager - bilibili = platform_manager["bilibili"](ProcessContext(), AsyncClient()) + bilibili = platform_manager["bilibili"](ProcessContext(DefaultClientManager())) tags = ["~111", "222", "333", "~444", "555"] res = bilibili.tag_separator(tags) assert res[0] == ["222", "333", "555"] diff --git a/tests/platforms/test_rss.py b/tests/platforms/test_rss.py index 7821088..99ac6b4 100644 --- a/tests/platforms/test_rss.py +++ b/tests/platforms/test_rss.py @@ -5,8 +5,8 @@ import xml.etree.ElementTree as ET import pytz import respx import pytest +from httpx import Response from nonebug.app import App -from httpx import Response, AsyncClient from .utils import get_file @@ -36,8 +36,9 @@ def user_info_factory(app: App, dummy_user): def rss(app: App): from nonebot_bison.utils import ProcessContext from nonebot_bison.platform import platform_manager + from nonebot_bison.utils.scheduler_config import DefaultClientManager - return platform_manager["rss"](ProcessContext(), AsyncClient()) + return platform_manager["rss"](ProcessContext(DefaultClientManager())) @pytest.fixture() diff --git a/tests/platforms/test_weibo.py b/tests/platforms/test_weibo.py index 16042c2..c2b5ec0 100644 --- a/tests/platforms/test_weibo.py +++ b/tests/platforms/test_weibo.py @@ -20,8 +20,9 @@ image_cdn_router = respx.route(host__regex=r"wx\d.sinaimg.cn", path__startswith= def weibo(app: App): from nonebot_bison.utils import ProcessContext from nonebot_bison.platform import platform_manager + from nonebot_bison.utils.scheduler_config import DefaultClientManager - return platform_manager["weibo"](ProcessContext(), AsyncClient()) + return platform_manager["weibo"](ProcessContext(DefaultClientManager())) @pytest.fixture(scope="module") diff --git a/tests/post/test_generate.py b/tests/post/test_generate.py index 0b00460..b0ff349 100644 --- a/tests/post/test_generate.py +++ b/tests/post/test_generate.py @@ -3,7 +3,6 @@ from typing import Any import pytest from nonebug.app import App -from httpx import AsyncClient now = time() passed = now - 3 * 60 * 60 @@ -173,8 +172,9 @@ async def test_generate_msg(mock_platform): from nonebot_bison.post import Post from nonebot_bison.utils import ProcessContext from nonebot_bison.plugin_config import plugin_config + from nonebot_bison.utils.scheduler_config import DefaultClientManager - post: Post = await mock_platform(ProcessContext(), AsyncClient()).parse(raw_post_list_1[0]) + post: Post = await mock_platform(ProcessContext(DefaultClientManager())).parse(raw_post_list_1[0]) assert post.platform.default_theme == "basic" res = await post.generate() assert len(res) == 1 @@ -203,10 +203,11 @@ async def test_msg_segments_convert(mock_platform): from nonebot_bison.post import Post from nonebot_bison.utils import ProcessContext from nonebot_bison.plugin_config import plugin_config + from nonebot_bison.utils.scheduler_config import DefaultClientManager plugin_config.bison_use_pic = True - post: Post = await mock_platform(ProcessContext(), AsyncClient()).parse(raw_post_list_1[0]) + post: Post = await mock_platform(ProcessContext(DefaultClientManager())).parse(raw_post_list_1[0]) assert post.platform.default_theme == "basic" res = await post.generate_messages() assert len(res) == 1 diff --git a/tests/scheduler/test_scheduler.py b/tests/scheduler/test_scheduler.py index b51401d..2850cde 100644 --- a/tests/scheduler/test_scheduler.py +++ b/tests/scheduler/test_scheduler.py @@ -3,7 +3,6 @@ from datetime import time from unittest.mock import AsyncMock from nonebug import App -from httpx import AsyncClient from pytest_mock import MockerFixture if typing.TYPE_CHECKING: @@ -61,11 +60,12 @@ async def test_scheduler_batch_api(init_scheduler, mocker: MockerFixture): from nonebot_bison.types import Target as T_Target from nonebot_bison.scheduler.manager import init_scheduler from nonebot_bison.platform.bilibili import BililiveSchedConf + from nonebot_bison.utils.scheduler_config import DefaultClientManager await config.add_subscribe(TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili-live", [], []) await config.add_subscribe(TargetQQGroup(group_id=123), T_Target("t2"), "target2", "bilibili-live", [], []) - mocker.patch.object(BililiveSchedConf, "get_client", return_value=AsyncClient()) + mocker.patch.object(BililiveSchedConf, "client_man", DefaultClientManager) await init_scheduler() diff --git a/tests/test_context.py b/tests/test_context.py index 26910d7..ac1f2c8 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -6,13 +6,14 @@ from nonebug.app import App @respx.mock async def test_http_error(app: App): from nonebot_bison.utils import ProcessContext, http_client + from nonebot_bison.utils.scheduler_config import DefaultClientManager example_route = respx.get("https://example.com") example_route.mock(httpx.Response(403, json={"error": "gg"})) - ctx = ProcessContext() + ctx = ProcessContext(DefaultClientManager()) async with http_client() as client: - ctx.register_to_client(client) + ctx._register_to_client(client) await client.get("https://example.com") assert ctx.gen_req_records() == [ diff --git a/tests/theme/test_themes.py b/tests/theme/test_themes.py index b12283b..40fd925 100644 --- a/tests/theme/test_themes.py +++ b/tests/theme/test_themes.py @@ -5,7 +5,6 @@ from inspect import cleandoc import pytest from flaky import flaky from nonebug import App -from httpx import AsyncClient now = time() passed = now - 3 * 60 * 60 @@ -69,9 +68,10 @@ def mock_platform(app: App): def mock_post(app: App, mock_platform): from nonebot_bison.post import Post from nonebot_bison.utils import ProcessContext + from nonebot_bison.utils.scheduler_config import DefaultClientManager return Post( - m := mock_platform(ProcessContext(), AsyncClient()), + m := mock_platform(ProcessContext(DefaultClientManager())), "text", title="title", images=["http://t.tt/1.jpg"],