♻️ refactor client of scheduler_config

This commit is contained in:
felinae98 2024-05-28 20:59:39 +08:00
parent c21b24b385
commit 2fc11a9653
30 changed files with 185 additions and 143 deletions

View File

@ -7,6 +7,6 @@ async def check_sub_target(platform_name: str, target: Target):
platform = platform_manager[platform_name] platform = platform_manager[platform_name]
scheduler_conf_class = platform.scheduler scheduler_conf_class = platform.scheduler
scheduler = scheduler_dict[scheduler_conf_class] scheduler = scheduler_dict[scheduler_conf_class]
client = await scheduler.scheduler_config_obj.get_query_name_client() client = await scheduler.client_mgr.get_query_name_client()
return await platform_manager[platform_name].get_target_name(client, target) return await platform_manager[platform_name].get_target_name(client, target)

View File

@ -74,7 +74,8 @@ class Arknights(NewMessage):
return "明日方舟游戏信息" return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[BulletinListItem]: async def get_sub_list(self, _) -> list[BulletinListItem]:
raw_data = await self.client.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS") client = await self.ctx.get_client()
raw_data = await client.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS")
return type_validate_python(ArkBulletinListResponse, raw_data.json()).data.list return type_validate_python(ArkBulletinListResponse, raw_data.json()).data.list
def get_id(self, post: BulletinListItem) -> Any: def get_id(self, post: BulletinListItem) -> Any:
@ -91,9 +92,8 @@ class Arknights(NewMessage):
return Category(1) return Category(1)
async def parse(self, raw_post: BulletinListItem) -> Post: async def parse(self, raw_post: BulletinListItem) -> Post:
raw_data = await self.client.get( client = await self.ctx.get_client()
f"https://ak-webview.hypergryph.com/api/game/bulletin/{self.get_id(post=raw_post)}" raw_data = await client.get(f"https://ak-webview.hypergryph.com/api/game/bulletin/{self.get_id(post=raw_post)}")
)
data = type_validate_python(ArkBulletinResponse, raw_data.json()).data data = type_validate_python(ArkBulletinResponse, raw_data.json()).data
def title_escape(text: str) -> str: def title_escape(text: str) -> str:
@ -136,8 +136,9 @@ class AkVersion(StatusChange):
return "明日方舟游戏信息" return "明日方舟游戏信息"
async def get_status(self, _): async def get_status(self, _):
res_ver = await self.client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version") client = await self.ctx.get_client()
res_preanounce = await self.client.get( res_ver = await client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version")
res_preanounce = await client.get(
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json" "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json"
) )
res = res_ver.json() res = res_ver.json()
@ -179,7 +180,8 @@ class MonsterSiren(NewMessage):
return "明日方舟游戏信息" return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]: async def get_sub_list(self, _) -> list[RawPost]:
raw_data = await self.client.get("https://monster-siren.hypergryph.com/api/news") client = await self.ctx.get_client()
raw_data = await client.get("https://monster-siren.hypergryph.com/api/news")
return raw_data.json()["data"]["list"] return raw_data.json()["data"]["list"]
def get_id(self, post: RawPost) -> Any: def get_id(self, post: RawPost) -> Any:
@ -192,8 +194,9 @@ class MonsterSiren(NewMessage):
return Category(3) return Category(3)
async def parse(self, raw_post: RawPost) -> Post: async def parse(self, raw_post: RawPost) -> Post:
client = await self.ctx.get_client()
url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}' url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}'
res = await self.client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}') res = await client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}')
raw_data = res.json() raw_data = res.json()
content = raw_data["data"]["content"] content = raw_data["data"]["content"]
content = content.replace("</p>", "</p>\n") content = content.replace("</p>", "</p>\n")
@ -226,7 +229,8 @@ class TerraHistoricusComic(NewMessage):
return "明日方舟游戏信息" return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]: async def get_sub_list(self, _) -> list[RawPost]:
raw_data = await self.client.get("https://terra-historicus.hypergryph.com/api/recentUpdate") client = await self.ctx.get_client()
raw_data = await client.get("https://terra-historicus.hypergryph.com/api/recentUpdate")
return raw_data.json()["data"] return raw_data.json()["data"]
def get_id(self, post: RawPost) -> Any: def get_id(self, post: RawPost) -> Any:

View File

@ -1,6 +1,5 @@
import re import re
import json import json
from abc import ABC
from copy import deepcopy from copy import deepcopy
from enum import Enum, unique from enum import Enum, unique
from typing_extensions import Self from typing_extensions import Self
@ -13,6 +12,7 @@ from pydantic import Field, BaseModel
from nonebot.compat import PYDANTIC_V2, ConfigDict, type_validate_json, type_validate_python from nonebot.compat import PYDANTIC_V2, ConfigDict, type_validate_json, type_validate_python
from nonebot_bison.compat import model_rebuild from nonebot_bison.compat import model_rebuild
from nonebot_bison.utils.scheduler_config import ClientManager
from ..post import Post from ..post import Post
from ..types import Tag, Target, RawPost, ApiError, Category from ..types import Tag, Target, RawPost, ApiError, Category
@ -104,7 +104,7 @@ model_rebuild_recurse(UserAPI)
model_rebuild_recurse(PostAPI) model_rebuild_recurse(PostAPI)
class BilibiliClient: class BilibiliClient(ClientManager):
_client: AsyncClient _client: AsyncClient
_refresh_time: datetime _refresh_time: datetime
cookie_expire_time = timedelta(hours=5) cookie_expire_time = timedelta(hours=5)
@ -124,37 +124,27 @@ class BilibiliClient:
if datetime.now() - self._refresh_time > self.cookie_expire_time: if datetime.now() - self._refresh_time > self.cookie_expire_time:
await self._init_session() await self._init_session()
async def get_client(self) -> AsyncClient: async def get_client(self, target: Target | None) -> AsyncClient:
await self._refresh_client()
return self._client
async def get_query_name_client(self) -> AsyncClient:
await self._refresh_client() await self._refresh_client()
return self._client return self._client
bilibili_client = BilibiliClient() class BilibiliSchedConf(SchedulerConfig):
class BaseSchedConf(ABC, SchedulerConfig):
schedule_type = "interval"
bilibili_client: BilibiliClient
def __init__(self):
super().__init__()
self.bilibili_client = bilibili_client
async def get_client(self, _: Target) -> AsyncClient:
return await self.bilibili_client.get_client()
async def get_query_name_client(self) -> AsyncClient:
return await self.bilibili_client.get_client()
class BilibiliSchedConf(BaseSchedConf):
name = "bilibili.com" name = "bilibili.com"
schedule_type = "interval"
schedule_setting = {"seconds": 10} schedule_setting = {"seconds": 10}
client_man = BilibiliClient
class BililiveSchedConf(BaseSchedConf): class BililiveSchedConf(SchedulerConfig):
name = "live.bilibili.com" name = "live.bilibili.com"
schedule_type = "interval"
schedule_setting = {"seconds": 3} schedule_setting = {"seconds": 3}
client_man = BilibiliClient
class Bilibili(NewMessage): class Bilibili(NewMessage):
@ -198,8 +188,9 @@ class Bilibili(NewMessage):
) )
async def get_sub_list(self, target: Target) -> list[DynRawPost]: async def get_sub_list(self, target: Target) -> list[DynRawPost]:
client = await self.ctx.get_client()
params = {"host_uid": target, "offset": 0, "need_top": 0} params = {"host_uid": target, "offset": 0, "need_top": 0}
res = await self.client.get( res = await client.get(
"https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history", "https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history",
params=params, params=params,
timeout=4.0, timeout=4.0,
@ -428,8 +419,9 @@ class Bilibililive(StatusChange):
) )
async def batch_get_status(self, targets: list[Target]) -> list[Info]: async def batch_get_status(self, targets: list[Target]) -> list[Info]:
client = await self.ctx.get_client()
# https://github.com/SocialSisterYi/bilibili-API-collect/blob/master/docs/live/info.md#批量查询直播间状态 # https://github.com/SocialSisterYi/bilibili-API-collect/blob/master/docs/live/info.md#批量查询直播间状态
res = await self.client.get( res = await client.get(
"https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids", "https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids",
params={"uids[]": targets}, params={"uids[]": targets},
timeout=4.0, timeout=4.0,
@ -520,7 +512,8 @@ class BilibiliBangumi(StatusChange):
) )
async def get_status(self, target: Target): async def get_status(self, target: Target):
res = await self.client.get( client = await self.ctx.get_client()
res = await client.get(
self._url, self._url,
params={"media_id": target}, params={"media_id": target},
timeout=4.0, timeout=4.0,
@ -542,9 +535,8 @@ class BilibiliBangumi(StatusChange):
return [] return []
async def parse(self, raw_post: RawPost) -> Post: async def parse(self, raw_post: RawPost) -> Post:
detail_res = await self.client.get( client = await self.ctx.get_client()
f'https://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}' detail_res = await client.get(f'https://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}')
)
detail_dict = detail_res.json() detail_dict = detail_res.json()
lastest_episode = None lastest_episode = None
for episode in detail_dict["result"]["episodes"][::-1]: for episode in detail_dict["result"]["episodes"][::-1]:

View File

@ -24,7 +24,8 @@ class FF14(NewMessage):
return "最终幻想XIV官方公告" return "最终幻想XIV官方公告"
async def get_sub_list(self, _) -> list[RawPost]: async def get_sub_list(self, _) -> list[RawPost]:
raw_data = await self.client.get( client = await self.ctx.get_client()
raw_data = await client.get(
"https://cqnews.web.sdo.com/api/news/newsList?gameCode=ff&CategoryCode=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5" "https://cqnews.web.sdo.com/api/news/newsList?gameCode=ff&CategoryCode=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5"
) )
return raw_data.json()["Data"] return raw_data.json()["Data"]

View File

@ -47,7 +47,8 @@ class NcmArtist(NewMessage):
raise cls.ParseTargetException("正确格式:\n1. 歌手数字ID\n2. https://music.163.com/#/artist?id=xxxx") raise cls.ParseTargetException("正确格式:\n1. 歌手数字ID\n2. https://music.163.com/#/artist?id=xxxx")
async def get_sub_list(self, target: Target) -> list[RawPost]: async def get_sub_list(self, target: Target) -> list[RawPost]:
res = await self.client.get( client = await self.ctx.get_client()
res = await client.get(
f"https://music.163.com/api/artist/albums/{target}", f"https://music.163.com/api/artist/albums/{target}",
headers={"Referer": "https://music.163.com/"}, headers={"Referer": "https://music.163.com/"},
) )
@ -106,7 +107,8 @@ class NcmRadio(NewMessage):
) )
async def get_sub_list(self, target: Target) -> list[RawPost]: async def get_sub_list(self, target: Target) -> list[RawPost]:
res = await self.client.post( client = await self.ctx.get_client()
res = await client.post(
"http://music.163.com/api/dj/program/byradio", "http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"}, headers={"Referer": "https://music.163.com/"},
data={"radioId": target, "limit": 1000, "offset": 0}, data={"radioId": target, "limit": 1000, "offset": 0},

View File

@ -92,7 +92,6 @@ class Platform(metaclass=PlatformABCMeta, base=True):
platform_name: str platform_name: str
parse_target_promot: str | None = None parse_target_promot: str | None = None
registry: list[type["Platform"]] registry: list[type["Platform"]]
client: AsyncClient
reverse_category: dict[str, Category] reverse_category: dict[str, Category]
use_batch: bool = False use_batch: bool = False
# TODO: 限定可使用的theme名称 # TODO: 限定可使用的theme名称
@ -121,9 +120,8 @@ class Platform(metaclass=PlatformABCMeta, base=True):
"actually function called" "actually function called"
return await self.parse(raw_post) return await self.parse(raw_post)
def __init__(self, context: ProcessContext, client: AsyncClient): def __init__(self, context: ProcessContext):
super().__init__() super().__init__()
self.client = client
self.ctx = context self.ctx = context
class ParseTargetException(Exception): class ParseTargetException(Exception):
@ -225,8 +223,8 @@ class Platform(metaclass=PlatformABCMeta, base=True):
class MessageProcess(Platform, abstract=True): class MessageProcess(Platform, abstract=True):
"General message process fetch, parse, filter progress" "General message process fetch, parse, filter progress"
def __init__(self, ctx: ProcessContext, client: AsyncClient): def __init__(self, ctx: ProcessContext):
super().__init__(ctx, client) super().__init__(ctx)
self.parse_cache: dict[Any, Post] = {} self.parse_cache: dict[Any, Post] = {}
@abstractmethod @abstractmethod
@ -463,11 +461,11 @@ def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
if platform.scheduler != scheduler: if platform.scheduler != scheduler:
raise RuntimeError(f"Platform scheduler for {platform_name} not fit") raise RuntimeError(f"Platform scheduler for {platform_name} not fit")
def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient): def __init__(self: "NoTargetGroup", ctx: ProcessContext):
Platform.__init__(self, ctx, client) Platform.__init__(self, ctx)
self.platform_obj_list = [] self.platform_obj_list = []
for platform_class in self.platform_list: for platform_class in self.platform_list:
self.platform_obj_list.append(platform_class(ctx, client)) self.platform_obj_list.append(platform_class(ctx))
def __str__(self: "NoTargetGroup") -> str: def __str__(self: "NoTargetGroup") -> str:
return "[" + " ".join(x.name for x in self.platform_list) + "]" return "[" + " ".join(x.name for x in self.platform_list) + "]"

View File

@ -46,7 +46,8 @@ class Rss(NewMessage):
return post.id return post.id
async def get_sub_list(self, target: Target) -> list[RawPost]: async def get_sub_list(self, target: Target) -> list[RawPost]:
res = await self.client.get(target, timeout=10.0) client = await self.ctx.get_client()
res = await client.get(target, timeout=10.0)
feed = feedparser.parse(res) feed = feedparser.parse(res)
entries = feed.entries entries = feed.entries
for entry in entries: for entry in entries:

View File

@ -78,8 +78,9 @@ class Weibo(NewMessage):
raise cls.ParseTargetException(prompt="正确格式:\n1. 用户数字UID\n2. https://weibo.com/u/xxxx") raise cls.ParseTargetException(prompt="正确格式:\n1. 用户数字UID\n2. https://weibo.com/u/xxxx")
async def get_sub_list(self, target: Target) -> list[RawPost]: async def get_sub_list(self, target: Target) -> list[RawPost]:
client = await self.ctx.get_client()
params = {"containerid": "107603" + target} params = {"containerid": "107603" + target}
res = await self.client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0) res = await client.get("https://m.weibo.cn/api/container/getIndex?", params=params, timeout=4.0)
res_data = json.loads(res.text) res_data = json.loads(res.text)
if not res_data["ok"] and res_data["msg"] != "这里还没有内容": if not res_data["ok"] and res_data["msg"] != "这里还没有内容":
raise ApiError(res.request.url) raise ApiError(res.request.url)
@ -149,7 +150,8 @@ class Weibo(NewMessage):
async def _get_long_weibo(self, weibo_id: str) -> dict: async def _get_long_weibo(self, weibo_id: str) -> dict:
try: try:
weibo_info = await self.client.get( client = await self.ctx.get_client()
weibo_info = await client.get(
"https://m.weibo.cn/statuses/show", "https://m.weibo.cn/statuses/show",
params={"id": weibo_id}, params={"id": weibo_id},
headers=_HEADER, headers=_HEADER,

View File

@ -5,6 +5,8 @@ from nonebot.log import logger
from nonebot_plugin_apscheduler import scheduler from nonebot_plugin_apscheduler import scheduler
from nonebot_plugin_saa.utils.exceptions import NoBotFound from nonebot_plugin_saa.utils.exceptions import NoBotFound
from nonebot_bison.utils.scheduler_config import ClientManager
from ..config import config from ..config import config
from ..send import send_msgs from ..send import send_msgs
from ..types import Target, SubUnit from ..types import Target, SubUnit
@ -24,6 +26,7 @@ class Scheduler:
schedulable_list: list[Schedulable] # for load weigth from db schedulable_list: list[Schedulable] # for load weigth from db
batch_api_target_cache: dict[str, dict[Target, list[Target]]] # platform_name -> (target -> [target]) batch_api_target_cache: dict[str, dict[Target, list[Target]]] # platform_name -> (target -> [target])
batch_platform_name_targets_cache: dict[str, list[Target]] batch_platform_name_targets_cache: dict[str, list[Target]]
client_mgr: ClientManager
def __init__( def __init__(
self, self,
@ -36,6 +39,7 @@ class Scheduler:
logger.error(f"scheduler config [{self.name}] not found, exiting") logger.error(f"scheduler config [{self.name}] not found, exiting")
raise RuntimeError(f"{self.name} not found") raise RuntimeError(f"{self.name} not found")
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.client_mgr = scheduler_config.client_mgr()
self.scheduler_config_obj = self.scheduler_config() self.scheduler_config_obj = self.scheduler_config()
self.schedulable_list = [] self.schedulable_list = []
@ -83,16 +87,14 @@ class Scheduler:
return cur_max_schedulable return cur_max_schedulable
async def exec_fetch(self): async def exec_fetch(self):
context = ProcessContext()
if not (schedulable := await self.get_next_schedulable()): if not (schedulable := await self.get_next_schedulable()):
return return
logger.trace(f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}") logger.trace(f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}")
client = await self.scheduler_config_obj.get_client(schedulable.target) context = ProcessContext(self.client_mgr)
context.register_to_client(client)
try: try:
platform_obj = platform_manager[schedulable.platform_name](context, client) platform_obj = platform_manager[schedulable.platform_name](context)
if schedulable.use_batch: if schedulable.use_batch:
batch_targets = self.batch_api_target_cache[schedulable.platform_name][schedulable.target] batch_targets = self.batch_api_target_cache[schedulable.platform_name][schedulable.target]
sub_units = [] sub_units = []

View File

@ -42,11 +42,12 @@ class BasicTheme(Theme):
if urls: if urls:
text += "\n".join(urls) text += "\n".join(urls)
client = await post.platform.ctx.get_client_for_static()
msgs: list[MessageSegmentFactory] = [Text(text)] msgs: list[MessageSegmentFactory] = [Text(text)]
if post.images: if post.images:
pics = post.images pics = post.images
if is_pics_mergable(pics): if is_pics_mergable(pics):
pics = await pic_merge(list(pics), post.platform.client) pics = await pic_merge(list(pics), client)
msgs.extend(map(Image, pics)) msgs.extend(map(Image, pics))
return msgs return msgs

View File

@ -29,11 +29,12 @@ class BriefTheme(Theme):
if urls: if urls:
text += "\n".join(urls) text += "\n".join(urls)
client = await post.platform.ctx.get_client_for_static()
msgs: list[MessageSegmentFactory] = [Text(text)] msgs: list[MessageSegmentFactory] = [Text(text)]
if post.images: if post.images:
pics = post.images pics = post.images
if is_pics_mergable(pics): if is_pics_mergable(pics):
pics = await pic_merge(list(pics), post.platform.client) pics = await pic_merge(list(pics), client)
msgs.append(Image(pics[0])) msgs.append(Image(pics[0]))
return msgs return msgs

View File

@ -54,9 +54,10 @@ class Ht2iTheme(Theme):
msgs.append(Text("\n".join(urls))) msgs.append(Text("\n".join(urls)))
if post.images: if post.images:
client = await post.platform.ctx.get_client_for_static()
pics = post.images pics = post.images
if is_pics_mergable(pics): if is_pics_mergable(pics):
pics = await pic_merge(list(pics), post.platform.client) pics = await pic_merge(list(pics), client)
msgs.extend(map(Image, pics)) msgs.extend(map(Image, pics))
return msgs return msgs

View File

@ -11,14 +11,16 @@ from nonebot_plugin_saa import Text, Image, MessageSegmentFactory
from .http import http_client from .http import http_client
from .context import ProcessContext from .context import ProcessContext
from ..plugin_config import plugin_config from ..plugin_config import plugin_config
from .scheduler_config import SchedulerConfig, scheduler
from .image import pic_merge, text_to_image, is_pics_mergable, pic_url_to_image from .image import pic_merge, text_to_image, is_pics_mergable, pic_url_to_image
from .scheduler_config import ClientManager, SchedulerConfig, DefaultClientManager, scheduler
__all__ = [ __all__ = [
"http_client", "http_client",
"Singleton", "Singleton",
"parse_text", "parse_text",
"ProcessContext", "ProcessContext",
"ClientManager",
"DefaultClientManager",
"html_to_text", "html_to_text",
"SchedulerConfig", "SchedulerConfig",
"scheduler", "scheduler",

View File

@ -2,19 +2,25 @@ from base64 import b64encode
from httpx import Response, AsyncClient from httpx import Response, AsyncClient
from nonebot_bison.types import Target
from .scheduler_config import ClientManager
class ProcessContext: class ProcessContext:
reqs: list[Response] reqs: list[Response]
_client_mgr: ClientManager
def __init__(self) -> None: def __init__(self, client_mgr: ClientManager) -> None:
self.reqs = [] self.reqs = []
self._client_mgr = client_mgr
def log_response(self, resp: Response): def _log_response(self, resp: Response):
self.reqs.append(resp) self.reqs.append(resp)
def register_to_client(self, client: AsyncClient): def _register_to_client(self, client: AsyncClient):
async def _log_to_ctx(r: Response): async def _log_to_ctx(r: Response):
self.log_response(r) self._log_response(r)
hooks = { hooks = {
"response": [_log_to_ctx], "response": [_log_to_ctx],
@ -41,3 +47,16 @@ class ProcessContext:
) )
res.append(log_content) res.append(log_content)
return res return res
async def get_client(self, target: Target | None = None) -> AsyncClient:
client = await self._client_mgr.get_client(target)
self._register_to_client(client)
return client
async def get_client_for_static(self) -> AsyncClient:
client = await self._client_mgr.get_client_for_static()
self._register_to_client(client)
return client
async def refresh_client(self):
await self._client_mgr.refresh_client()

View File

@ -1,3 +1,4 @@
from abc import ABC
from typing import Literal from typing import Literal
from httpx import AsyncClient from httpx import AsyncClient
@ -6,10 +7,32 @@ from ..types import Target
from .http import http_client from .http import http_client
class ClientManager(ABC):
async def get_client(self, target: Target | None) -> AsyncClient: ...
async def get_client_for_static(self) -> AsyncClient: ...
async def get_query_name_client(self) -> AsyncClient: ...
async def refresh_client(self): ...
class DefaultClientManager(ClientManager):
async def get_client(self, target: Target | None) -> AsyncClient:
return http_client()
async def get_client_for_static(self) -> AsyncClient:
return http_client()
async def get_query_name_client(self) -> AsyncClient:
return http_client()
class SchedulerConfig: class SchedulerConfig:
schedule_type: Literal["date", "interval", "cron"] schedule_type: Literal["date", "interval", "cron"]
schedule_setting: dict schedule_setting: dict
name: str name: str
client_mgr: type[ClientManager] = DefaultClientManager
require_browser: bool = False require_browser: bool = False
def __str__(self): def __str__(self):
@ -18,12 +41,6 @@ class SchedulerConfig:
def __init__(self): def __init__(self):
self.default_http_client = http_client() self.default_http_client = http_client()
async def get_client(self, target: Target) -> AsyncClient:
return self.default_http_client
async def get_query_name_client(self) -> AsyncClient:
return self.default_http_client
def scheduler(schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict) -> type[SchedulerConfig]: def scheduler(schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict) -> type[SchedulerConfig]:
return type( return type(
@ -32,5 +49,6 @@ def scheduler(schedule_type: Literal["date", "interval", "cron"], schedule_setti
{ {
"schedule_type": schedule_type, "schedule_type": schedule_type,
"schedule_setting": schedule_setting, "schedule_setting": schedule_setting,
"client_mgr": ClientManager,
}, },
) )

View File

@ -2,8 +2,8 @@ from time import time
import respx import respx
import pytest import pytest
from httpx import Response
from nonebug.app import App from nonebug.app import App
from httpx import Response, AsyncClient
from nonebot.compat import model_dump, type_validate_python from nonebot.compat import model_dump, type_validate_python
from .utils import get_file, get_json from .utils import get_file, get_json
@ -13,8 +13,9 @@ from .utils import get_file, get_json
def arknights(app: App): def arknights(app: App):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils.scheduler_config import DefaultClientManager
return platform_manager["arknights"](ProcessContext(), AsyncClient()) return platform_manager["arknights"](ProcessContext(DefaultClientManager()))
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -44,9 +45,8 @@ def monster_siren_list_1():
@respx.mock @respx.mock
async def test_url_parse(app: App): async def test_url_parse(app: App):
from httpx import AsyncClient
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.utils.scheduler_config import DefaultClientManager
from nonebot_bison.platform.arknights import Arknights, BulletinData, BulletinListItem, ArkBulletinResponse from nonebot_bison.platform.arknights import Arknights, BulletinData, BulletinListItem, ArkBulletinResponse
cid_router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletin/1") cid_router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletin/1")
@ -93,7 +93,7 @@ async def test_url_parse(app: App):
b4 = make_bulletin_obj("http://www.baidu.com") b4 = make_bulletin_obj("http://www.baidu.com")
assert b4.jump_link == "http://www.baidu.com" assert b4.jump_link == "http://www.baidu.com"
ark = Arknights(ProcessContext(), AsyncClient()) ark = Arknights(ProcessContext(DefaultClientManager()))
cid_router.mock(return_value=make_response(b1)) cid_router.mock(return_value=make_response(b1))
p1 = await ark.parse(make_bulletin_list_item_obj()) p1 = await ark.parse(make_bulletin_list_item_obj())
@ -115,9 +115,10 @@ async def test_url_parse(app: App):
@pytest.mark.asyncio() @pytest.mark.asyncio()
async def test_get_date_in_bulletin(app: App): async def test_get_date_in_bulletin(app: App):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.utils.scheduler_config import DefaultClientManager
from nonebot_bison.platform.arknights import Arknights, BulletinListItem from nonebot_bison.platform.arknights import Arknights, BulletinListItem
arknights = Arknights(ProcessContext(), AsyncClient()) arknights = Arknights(ProcessContext(DefaultClientManager()))
assert ( assert (
arknights.get_date( arknights.get_date(
BulletinListItem( BulletinListItem(
@ -136,13 +137,14 @@ async def test_get_date_in_bulletin(app: App):
@pytest.mark.asyncio() @pytest.mark.asyncio()
@respx.mock @respx.mock
async def test_parse_with_breakline(app: App): async def test_parse_with_breakline(app: App):
from nonebot_bison.utils import ProcessContext, http_client from nonebot_bison.utils import ProcessContext
from nonebot_bison.utils.scheduler_config import DefaultClientManager
from nonebot_bison.platform.arknights import Arknights, BulletinListItem from nonebot_bison.platform.arknights import Arknights, BulletinListItem
detail = get_json("arknights-detail-805") detail = get_json("arknights-detail-805")
detail["data"]["header"] = "" detail["data"]["header"] = ""
arknights = Arknights(ProcessContext(), http_client()) arknights = Arknights(ProcessContext(DefaultClientManager()))
router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletin/1") router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletin/1")
router.mock(return_value=Response(200, json=detail)) router.mock(return_value=Response(200, json=detail))

View File

@ -3,8 +3,8 @@ from datetime import datetime
import respx import respx
import pytest import pytest
from httpx import Response
from nonebug.app import App from nonebug.app import App
from httpx import Response, AsyncClient
from nonebot.compat import model_dump, type_validate_python from nonebot.compat import model_dump, type_validate_python
from .utils import get_json from .utils import get_json
@ -25,8 +25,9 @@ if typing.TYPE_CHECKING:
def bilibili(app: App) -> "Bilibili": def bilibili(app: App) -> "Bilibili":
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils.scheduler_config import DefaultClientManager
return platform_manager["bilibili"](ProcessContext(), AsyncClient()) # type: ignore return platform_manager["bilibili"](ProcessContext(DefaultClientManager())) # type: ignore
@pytest.fixture() @pytest.fixture()

View File

@ -2,8 +2,8 @@ import typing
import respx import respx
import pytest import pytest
from httpx import Response
from nonebug.app import App from nonebug.app import App
from httpx import Response, AsyncClient
from .utils import get_json from .utils import get_json
@ -15,8 +15,9 @@ if typing.TYPE_CHECKING:
def bili_bangumi(app: App): def bili_bangumi(app: App):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils.scheduler_config import DefaultClientManager
return platform_manager["bilibili-bangumi"](ProcessContext(), AsyncClient()) return platform_manager["bilibili-bangumi"](ProcessContext(DefaultClientManager()))
async def test_parse_target(bili_bangumi: "BilibiliBangumi"): async def test_parse_target(bili_bangumi: "BilibiliBangumi"):

View File

@ -3,8 +3,8 @@ from typing import TYPE_CHECKING
import respx import respx
import pytest import pytest
from httpx import Response
from nonebug.app import App from nonebug.app import App
from httpx import Response, AsyncClient
from .utils import get_json from .utils import get_json
@ -16,8 +16,9 @@ if TYPE_CHECKING:
def bili_live(app: App): def bili_live(app: App):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.platform.bilibili import BilibiliClient
return platform_manager["bilibili-live"](ProcessContext(), AsyncClient()) return platform_manager["bilibili-live"](ProcessContext(BilibiliClient()))
@pytest.fixture() @pytest.fixture()
@ -30,27 +31,6 @@ def dummy_only_open_user_subinfo(app: App):
return UserSubInfo(user=user, categories=[1], tags=[]) return UserSubInfo(user=user, categories=[1], tags=[])
@pytest.mark.asyncio
async def test_http_client_equal(app: App):
from nonebot_bison.types import Target
from nonebot_bison.utils import ProcessContext
from nonebot_bison.platform import platform_manager
empty_target = Target("0")
bilibili = platform_manager["bilibili"](ProcessContext(), AsyncClient())
bilibili_live = platform_manager["bilibili-live"](ProcessContext(), AsyncClient())
bilibili_scheduler = bilibili.scheduler()
bilibili_live_scheduler = bilibili_live.scheduler()
assert await bilibili_scheduler.get_client(empty_target) == await bilibili_live_scheduler.get_client(empty_target)
assert await bilibili_live_scheduler.get_client(empty_target) != bilibili_live_scheduler.default_http_client
assert await bilibili_scheduler.get_query_name_client() == await bilibili_live_scheduler.get_query_name_client()
assert await bilibili_scheduler.get_query_name_client() != bilibili_live_scheduler.default_http_client
@pytest.mark.asyncio @pytest.mark.asyncio
@respx.mock @respx.mock
async def test_fetch_bililive_no_room(bili_live, dummy_only_open_user_subinfo): async def test_fetch_bililive_no_room(bili_live, dummy_only_open_user_subinfo):

View File

@ -1,7 +1,7 @@
import respx import respx
import pytest import pytest
from httpx import Response
from nonebug.app import App from nonebug.app import App
from httpx import Response, AsyncClient
from .utils import get_json from .utils import get_json
@ -10,8 +10,9 @@ from .utils import get_json
def ff14(app: App): def ff14(app: App):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils.scheduler_config import DefaultClientManager
return platform_manager["ff14"](ProcessContext(), AsyncClient()) return platform_manager["ff14"](ProcessContext(DefaultClientManager()))
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -3,8 +3,8 @@ import typing
import respx import respx
import pytest import pytest
from httpx import Response
from nonebug.app import App from nonebug.app import App
from httpx import Response, AsyncClient
from .utils import get_json from .utils import get_json
@ -16,8 +16,9 @@ if typing.TYPE_CHECKING:
def ncm_artist(app: App): def ncm_artist(app: App):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils.scheduler_config import DefaultClientManager
return platform_manager["ncm-artist"](ProcessContext(), AsyncClient()) return platform_manager["ncm-artist"](ProcessContext(DefaultClientManager()))
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -3,8 +3,8 @@ import typing
import respx import respx
import pytest import pytest
from httpx import Response
from nonebug.app import App from nonebug.app import App
from httpx import Response, AsyncClient
from .utils import get_json from .utils import get_json
@ -16,8 +16,9 @@ if typing.TYPE_CHECKING:
def ncm_radio(app: App): def ncm_radio(app: App):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils.scheduler_config import DefaultClientManager
return platform_manager["ncm-radio"](ProcessContext(), AsyncClient()) return platform_manager["ncm-radio"](ProcessContext(DefaultClientManager()))
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -3,7 +3,6 @@ from typing import Any
import pytest import pytest
from nonebug.app import App from nonebug.app import App
from httpx import AsyncClient
now = time() now = time()
passed = now - 3 * 60 * 60 passed = now - 3 * 60 * 60
@ -326,12 +325,13 @@ def mock_status_change(app: App):
async def test_new_message_target_without_cats_tags(mock_platform_without_cats_tags, user_info_factory): async def test_new_message_target_without_cats_tags(mock_platform_without_cats_tags, user_info_factory):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.types import Target, SubUnit from nonebot_bison.types import Target, SubUnit
from nonebot_bison.utils.scheduler_config import DefaultClientManager
res1 = await mock_platform_without_cats_tags(ProcessContext(), AsyncClient()).fetch_new_post( res1 = await mock_platform_without_cats_tags(ProcessContext(DefaultClientManager())).fetch_new_post(
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
) )
assert len(res1) == 0 assert len(res1) == 0
res2 = await mock_platform_without_cats_tags(ProcessContext(), AsyncClient()).fetch_new_post( res2 = await mock_platform_without_cats_tags(ProcessContext(DefaultClientManager())).fetch_new_post(
SubUnit(Target("dummy"), [user_info_factory([], [])]), SubUnit(Target("dummy"), [user_info_factory([], [])]),
) )
assert len(res2) == 1 assert len(res2) == 1
@ -347,12 +347,13 @@ async def test_new_message_target_without_cats_tags(mock_platform_without_cats_t
async def test_new_message_target(mock_platform, user_info_factory): async def test_new_message_target(mock_platform, user_info_factory):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.types import Target, SubUnit from nonebot_bison.types import Target, SubUnit
from nonebot_bison.utils.scheduler_config import DefaultClientManager
res1 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post( res1 = await mock_platform(ProcessContext(DefaultClientManager())).fetch_new_post(
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
) )
assert len(res1) == 0 assert len(res1) == 0
res2 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post( res2 = await mock_platform(ProcessContext(DefaultClientManager())).fetch_new_post(
SubUnit( SubUnit(
Target("dummy"), Target("dummy"),
[ [
@ -382,12 +383,13 @@ async def test_new_message_target(mock_platform, user_info_factory):
async def test_new_message_no_target(mock_platform_no_target, user_info_factory): async def test_new_message_no_target(mock_platform_no_target, user_info_factory):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.types import Target, SubUnit from nonebot_bison.types import Target, SubUnit
from nonebot_bison.utils.scheduler_config import DefaultClientManager
res1 = await mock_platform_no_target(ProcessContext(), AsyncClient()).fetch_new_post( res1 = await mock_platform_no_target(ProcessContext(DefaultClientManager())).fetch_new_post(
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
) )
assert len(res1) == 0 assert len(res1) == 0
res2 = await mock_platform_no_target(ProcessContext(), AsyncClient()).fetch_new_post( res2 = await mock_platform_no_target(ProcessContext(DefaultClientManager())).fetch_new_post(
SubUnit( SubUnit(
Target("dummy"), Target("dummy"),
[ [
@ -411,7 +413,7 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory)
assert "p3" in id_set_1 assert "p3" in id_set_1
assert "p2" in id_set_2 assert "p2" in id_set_2
assert "p2" in id_set_3 assert "p2" in id_set_3
res3 = await mock_platform_no_target(ProcessContext(), AsyncClient()).fetch_new_post( res3 = await mock_platform_no_target(ProcessContext(DefaultClientManager())).fetch_new_post(
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
) )
assert len(res3) == 0 assert len(res3) == 0
@ -421,19 +423,20 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory)
async def test_status_change(mock_status_change, user_info_factory): async def test_status_change(mock_status_change, user_info_factory):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.types import Target, SubUnit from nonebot_bison.types import Target, SubUnit
from nonebot_bison.utils.scheduler_config import DefaultClientManager
res1 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( res1 = await mock_status_change(ProcessContext(DefaultClientManager())).fetch_new_post(
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
) )
assert len(res1) == 0 assert len(res1) == 0
res2 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( res2 = await mock_status_change(ProcessContext(DefaultClientManager())).fetch_new_post(
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
) )
assert len(res2) == 1 assert len(res2) == 1
posts = res2[0][1] posts = res2[0][1]
assert len(posts) == 1 assert len(posts) == 1
assert posts[0].content == "on" assert posts[0].content == "on"
res3 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( res3 = await mock_status_change(ProcessContext(DefaultClientManager())).fetch_new_post(
SubUnit( SubUnit(
Target("dummy"), Target("dummy"),
[ [
@ -446,7 +449,7 @@ async def test_status_change(mock_status_change, user_info_factory):
assert len(res3[0][1]) == 1 assert len(res3[0][1]) == 1
assert res3[0][1][0].content == "off" assert res3[0][1][0].content == "off"
assert len(res3[1][1]) == 0 assert len(res3[1][1]) == 0
res4 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( res4 = await mock_status_change(ProcessContext(DefaultClientManager())).fetch_new_post(
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
) )
assert len(res4) == 0 assert len(res4) == 0
@ -459,14 +462,15 @@ async def test_group(
mock_platform_no_target_2, mock_platform_no_target_2,
user_info_factory, user_info_factory,
): ):
from nonebot_bison.utils import ProcessContext
from nonebot_bison.types import Target, SubUnit from nonebot_bison.types import Target, SubUnit
from nonebot_bison.utils import ProcessContext, http_client
from nonebot_bison.platform.platform import make_no_target_group from nonebot_bison.platform.platform import make_no_target_group
from nonebot_bison.utils.scheduler_config import DefaultClientManager
dummy = Target("dummy") dummy = Target("dummy")
group_platform_class = make_no_target_group([mock_platform_no_target, mock_platform_no_target_2]) group_platform_class = make_no_target_group([mock_platform_no_target, mock_platform_no_target_2])
group_platform = group_platform_class(ProcessContext(), http_client()) group_platform = group_platform_class(ProcessContext(DefaultClientManager()))
res1 = await group_platform.fetch_new_post(SubUnit(dummy, [user_info_factory([1, 4], [])])) res1 = await group_platform.fetch_new_post(SubUnit(dummy, [user_info_factory([1, 4], [])]))
assert len(res1) == 0 assert len(res1) == 0
res2 = await group_platform.fetch_new_post(SubUnit(dummy, [user_info_factory([1, 4], [])])) res2 = await group_platform.fetch_new_post(SubUnit(dummy, [user_info_factory([1, 4], [])]))
@ -487,6 +491,7 @@ async def test_batch_fetch_new_message(app: App):
from nonebot_bison.platform.platform import NewMessage from nonebot_bison.platform.platform import NewMessage
from nonebot_bison.utils.context import ProcessContext from nonebot_bison.utils.context import ProcessContext
from nonebot_bison.types import Target, RawPost, SubUnit, UserSubInfo from nonebot_bison.types import Target, RawPost, SubUnit, UserSubInfo
from nonebot_bison.utils.scheduler_config import DefaultClientManager
class BatchNewMessage(NewMessage): class BatchNewMessage(NewMessage):
platform_name = "mock_platform" platform_name = "mock_platform"
@ -538,7 +543,7 @@ async def test_batch_fetch_new_message(app: App):
user1 = UserSubInfo(TargetQQGroup(group_id=123), [1, 2, 3], []) user1 = UserSubInfo(TargetQQGroup(group_id=123), [1, 2, 3], [])
user2 = UserSubInfo(TargetQQGroup(group_id=234), [1, 2, 3], []) user2 = UserSubInfo(TargetQQGroup(group_id=234), [1, 2, 3], [])
platform_obj = BatchNewMessage(ProcessContext(), None) # type:ignore platform_obj = BatchNewMessage(ProcessContext(DefaultClientManager())) # type:ignore
res1 = await platform_obj.batch_fetch_new_post( res1 = await platform_obj.batch_fetch_new_post(
[ [
@ -572,6 +577,7 @@ async def test_batch_fetch_compare_status(app: App):
from nonebot_bison.post import Post from nonebot_bison.post import Post
from nonebot_bison.utils.context import ProcessContext from nonebot_bison.utils.context import ProcessContext
from nonebot_bison.platform.platform import StatusChange from nonebot_bison.platform.platform import StatusChange
from nonebot_bison.utils.scheduler_config import DefaultClientManager
from nonebot_bison.types import Target, RawPost, SubUnit, Category, UserSubInfo from nonebot_bison.types import Target, RawPost, SubUnit, Category, UserSubInfo
class BatchStatusChange(StatusChange): class BatchStatusChange(StatusChange):
@ -612,7 +618,7 @@ async def test_batch_fetch_compare_status(app: App):
def get_category(self, raw_post): def get_category(self, raw_post):
return raw_post["cat"] return raw_post["cat"]
batch_status_change = BatchStatusChange(ProcessContext(), None) # type: ignore batch_status_change = BatchStatusChange(ProcessContext(DefaultClientManager()))
user1 = UserSubInfo(TargetQQGroup(group_id=123), [1, 2, 3], []) user1 = UserSubInfo(TargetQQGroup(group_id=123), [1, 2, 3], [])
user2 = UserSubInfo(TargetQQGroup(group_id=234), [1, 2, 3], []) user2 = UserSubInfo(TargetQQGroup(group_id=234), [1, 2, 3], [])

View File

@ -1,6 +1,5 @@
import pytest import pytest
from nonebug.app import App from nonebug.app import App
from httpx import AsyncClient
from .utils import get_json from .utils import get_json
@ -15,8 +14,9 @@ def test_cases():
async def test_filter_user_custom_tag(app: App, test_cases): async def test_filter_user_custom_tag(app: App, test_cases):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils.scheduler_config import DefaultClientManager
bilibili = platform_manager["bilibili"](ProcessContext(), AsyncClient()) bilibili = platform_manager["bilibili"](ProcessContext(DefaultClientManager()))
for case in test_cases: for case in test_cases:
res = bilibili.is_banned_post(**case["case"]) res = bilibili.is_banned_post(**case["case"])
assert res == case["result"] assert res == case["result"]
@ -27,8 +27,9 @@ async def test_filter_user_custom_tag(app: App, test_cases):
async def test_tag_separator(app: App): async def test_tag_separator(app: App):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils.scheduler_config import DefaultClientManager
bilibili = platform_manager["bilibili"](ProcessContext(), AsyncClient()) bilibili = platform_manager["bilibili"](ProcessContext(DefaultClientManager()))
tags = ["~111", "222", "333", "~444", "555"] tags = ["~111", "222", "333", "~444", "555"]
res = bilibili.tag_separator(tags) res = bilibili.tag_separator(tags)
assert res[0] == ["222", "333", "555"] assert res[0] == ["222", "333", "555"]

View File

@ -5,8 +5,8 @@ import xml.etree.ElementTree as ET
import pytz import pytz
import respx import respx
import pytest import pytest
from httpx import Response
from nonebug.app import App from nonebug.app import App
from httpx import Response, AsyncClient
from .utils import get_file from .utils import get_file
@ -36,8 +36,9 @@ def user_info_factory(app: App, dummy_user):
def rss(app: App): def rss(app: App):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils.scheduler_config import DefaultClientManager
return platform_manager["rss"](ProcessContext(), AsyncClient()) return platform_manager["rss"](ProcessContext(DefaultClientManager()))
@pytest.fixture() @pytest.fixture()

View File

@ -20,8 +20,9 @@ image_cdn_router = respx.route(host__regex=r"wx\d.sinaimg.cn", path__startswith=
def weibo(app: App): def weibo(app: App):
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils.scheduler_config import DefaultClientManager
return platform_manager["weibo"](ProcessContext(), AsyncClient()) return platform_manager["weibo"](ProcessContext(DefaultClientManager()))
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -3,7 +3,6 @@ from typing import Any
import pytest import pytest
from nonebug.app import App from nonebug.app import App
from httpx import AsyncClient
now = time() now = time()
passed = now - 3 * 60 * 60 passed = now - 3 * 60 * 60
@ -173,8 +172,9 @@ async def test_generate_msg(mock_platform):
from nonebot_bison.post import Post from nonebot_bison.post import Post
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.plugin_config import plugin_config from nonebot_bison.plugin_config import plugin_config
from nonebot_bison.utils.scheduler_config import DefaultClientManager
post: Post = await mock_platform(ProcessContext(), AsyncClient()).parse(raw_post_list_1[0]) post: Post = await mock_platform(ProcessContext(DefaultClientManager())).parse(raw_post_list_1[0])
assert post.platform.default_theme == "basic" assert post.platform.default_theme == "basic"
res = await post.generate() res = await post.generate()
assert len(res) == 1 assert len(res) == 1
@ -203,10 +203,11 @@ async def test_msg_segments_convert(mock_platform):
from nonebot_bison.post import Post from nonebot_bison.post import Post
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.plugin_config import plugin_config from nonebot_bison.plugin_config import plugin_config
from nonebot_bison.utils.scheduler_config import DefaultClientManager
plugin_config.bison_use_pic = True plugin_config.bison_use_pic = True
post: Post = await mock_platform(ProcessContext(), AsyncClient()).parse(raw_post_list_1[0]) post: Post = await mock_platform(ProcessContext(DefaultClientManager())).parse(raw_post_list_1[0])
assert post.platform.default_theme == "basic" assert post.platform.default_theme == "basic"
res = await post.generate_messages() res = await post.generate_messages()
assert len(res) == 1 assert len(res) == 1

View File

@ -3,7 +3,6 @@ from datetime import time
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
from nonebug import App from nonebug import App
from httpx import AsyncClient
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
@ -61,11 +60,12 @@ async def test_scheduler_batch_api(init_scheduler, mocker: MockerFixture):
from nonebot_bison.types import Target as T_Target from nonebot_bison.types import Target as T_Target
from nonebot_bison.scheduler.manager import init_scheduler from nonebot_bison.scheduler.manager import init_scheduler
from nonebot_bison.platform.bilibili import BililiveSchedConf from nonebot_bison.platform.bilibili import BililiveSchedConf
from nonebot_bison.utils.scheduler_config import DefaultClientManager
await config.add_subscribe(TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili-live", [], []) await config.add_subscribe(TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili-live", [], [])
await config.add_subscribe(TargetQQGroup(group_id=123), T_Target("t2"), "target2", "bilibili-live", [], []) await config.add_subscribe(TargetQQGroup(group_id=123), T_Target("t2"), "target2", "bilibili-live", [], [])
mocker.patch.object(BililiveSchedConf, "get_client", return_value=AsyncClient()) mocker.patch.object(BililiveSchedConf, "client_man", DefaultClientManager)
await init_scheduler() await init_scheduler()

View File

@ -6,13 +6,14 @@ from nonebug.app import App
@respx.mock @respx.mock
async def test_http_error(app: App): async def test_http_error(app: App):
from nonebot_bison.utils import ProcessContext, http_client from nonebot_bison.utils import ProcessContext, http_client
from nonebot_bison.utils.scheduler_config import DefaultClientManager
example_route = respx.get("https://example.com") example_route = respx.get("https://example.com")
example_route.mock(httpx.Response(403, json={"error": "gg"})) example_route.mock(httpx.Response(403, json={"error": "gg"}))
ctx = ProcessContext() ctx = ProcessContext(DefaultClientManager())
async with http_client() as client: async with http_client() as client:
ctx.register_to_client(client) ctx._register_to_client(client)
await client.get("https://example.com") await client.get("https://example.com")
assert ctx.gen_req_records() == [ assert ctx.gen_req_records() == [

View File

@ -5,7 +5,6 @@ from inspect import cleandoc
import pytest import pytest
from flaky import flaky from flaky import flaky
from nonebug import App from nonebug import App
from httpx import AsyncClient
now = time() now = time()
passed = now - 3 * 60 * 60 passed = now - 3 * 60 * 60
@ -69,9 +68,10 @@ def mock_platform(app: App):
def mock_post(app: App, mock_platform): def mock_post(app: App, mock_platform):
from nonebot_bison.post import Post from nonebot_bison.post import Post
from nonebot_bison.utils import ProcessContext from nonebot_bison.utils import ProcessContext
from nonebot_bison.utils.scheduler_config import DefaultClientManager
return Post( return Post(
m := mock_platform(ProcessContext(), AsyncClient()), m := mock_platform(ProcessContext(DefaultClientManager())),
"text", "text",
title="title", title="title",
images=["http://t.tt/1.jpg"], images=["http://t.tt/1.jpg"],