From e7dcfdee505daa5a7be1bf3a2bd231aa5e002d89 Mon Sep 17 00:00:00 2001 From: felinae98 <731499577@qq.com> Date: Tue, 29 Aug 2023 21:12:42 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20=E6=8F=90=E4=BE=9B=E6=89=B9?= =?UTF-8?q?=E9=87=8F=20api=20=E6=8E=A5=E5=8F=A3=20(#290)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * :construction: add batch api * :construction: support batch in scheduler * :sparkles: use batch api in bilibili-live :white_check_mark: patch platform_manager directly :sparkles: use batch api in bilibili-live :white_check_mark: patch platform_manager directly * :recycle: refactor var name * :bug: fix test * :bug: fix scheduler * :bug: fix test --- nonebot_bison/platform/bilibili.py | 19 +- nonebot_bison/platform/platform.py | 147 +++++++++----- nonebot_bison/scheduler/manager.py | 4 +- nonebot_bison/scheduler/scheduler.py | 52 ++++- nonebot_bison/types.py | 5 + tests/platforms/test_arknights.py | 18 +- tests/platforms/test_bilibili.py | 14 +- tests/platforms/test_bilibili_bangumi.py | 8 +- tests/platforms/test_bilibili_live.py | 72 ++++--- tests/platforms/test_ff14.py | 8 +- tests/platforms/test_ncm_artist.py | 8 +- tests/platforms/test_ncm_radio.py | 8 +- tests/platforms/test_platform.py | 235 ++++++++++++++++++++--- tests/platforms/test_rss.py | 32 +-- tests/platforms/test_weibo.py | 14 +- tests/scheduler/test_scheduler.py | 41 ++++ 16 files changed, 519 insertions(+), 166 deletions(-) diff --git a/nonebot_bison/platform/bilibili.py b/nonebot_bison/platform/bilibili.py index fabd9c2..75d898e 100644 --- a/nonebot_bison/platform/bilibili.py +++ b/nonebot_bison/platform/bilibili.py @@ -201,6 +201,7 @@ class Bilibililive(StatusChange): scheduler = BilibiliSchedConf name = "Bilibili直播" has_target = True + use_batch = True @unique class LiveStatus(Enum): @@ -281,12 +282,11 @@ class Bilibililive(StatusChange): keyframe="", ) - async def get_status(self, target: Target) -> Info: - params = {"uids[]": target} + async def batch_get_status(self, targets: list[Target]) -> list[Info]: # https://github.com/SocialSisterYi/bilibili-API-collect/blob/master/docs/live/info.md#批量查询直播间状态 res = await self.client.get( "https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids", - params=params, + params={"uids[]": targets}, timeout=4.0, ) res_dict = res.json() @@ -294,11 +294,14 @@ class Bilibililive(StatusChange): if res_dict["code"] != 0: raise self.FetchError() - data = res_dict.get("data") - if not data: - return self._gen_empty_info(uid=int(target)) - room_data = data[target] - return self.Info.parse_obj(room_data) + data = res_dict.get("data", {}) + infos = [] + for target in targets: + if target in data.keys(): + infos.append(self.Info.parse_obj(data[target])) + else: + infos.append(self._gen_empty_info(int(target))) + return infos def compare_status(self, _: Target, old_status: Info, new_status: Info) -> list[RawPost]: action = Bilibililive.LiveAction diff --git a/nonebot_bison/platform/platform.py b/nonebot_bison/platform/platform.py index 5dfbf89..5b98968 100644 --- a/nonebot_bison/platform/platform.py +++ b/nonebot_bison/platform/platform.py @@ -2,11 +2,11 @@ import ssl import json import time import typing -from typing import Any from dataclasses import dataclass from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import Collection +from typing import Any, TypeVar, ParamSpec +from collections.abc import Callable, Awaitable, Collection import httpx from httpx import AsyncClient @@ -16,7 +16,7 @@ from nonebot_plugin_saa import PlatformTarget from ..post import Post from ..plugin_config import plugin_config from ..utils import ProcessContext, SchedulerConfig -from ..types import Tag, Target, RawPost, Category, UserSubInfo +from ..types import Tag, Target, RawPost, SubUnit, Category class CategoryNotSupport(Exception): @@ -44,6 +44,26 @@ class RegistryMeta(type): super().__init__(name, bases, namespace, **kwargs) +P = ParamSpec("P") +R = TypeVar("R") + + +async def catch_network_error(func: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs) -> R | None: + try: + return await func(*args, **kwargs) + except httpx.RequestError as err: + if plugin_config.bison_show_network_warning: + logger.warning(f"network connection error: {type(err)}, url: {err.request.url}") + return None + except ssl.SSLError as err: + if plugin_config.bison_show_network_warning: + logger.warning(f"ssl error: {err}") + return None + except json.JSONDecodeError as err: + logger.warning(f"json error, parsing: {err.doc}") + raise err + + class PlatformMeta(RegistryMeta): categories: dict[Category, str] store: dict[Target, Any] @@ -75,6 +95,7 @@ class Platform(metaclass=PlatformABCMeta, base=True): registry: list[type["Platform"]] client: AsyncClient reverse_category: dict[str, Category] + use_batch: bool = False @classmethod @abstractmethod @@ -82,25 +103,18 @@ class Platform(metaclass=PlatformABCMeta, base=True): ... @abstractmethod - async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]: + async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]: ... - async def do_fetch_new_post( - self, target: Target, users: list[UserSubInfo] - ) -> list[tuple[PlatformTarget, list[Post]]]: - try: - return await self.fetch_new_post(target, users) - except httpx.RequestError as err: - if plugin_config.bison_show_network_warning: - logger.warning(f"network connection error: {type(err)}, url: {err.request.url}") - return [] - except ssl.SSLError as err: - if plugin_config.bison_show_network_warning: - logger.warning(f"ssl error: {err}") - return [] - except json.JSONDecodeError as err: - logger.warning(f"json error, parsing: {err.doc}") - raise err + async def do_fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]: + return await catch_network_error(self.fetch_new_post, sub_unit) or [] + + @abstractmethod + async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]: + ... + + async def do_batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]: + return await catch_network_error(self.batch_fetch_new_post, sub_units) or [] @abstractmethod async def parse(self, raw_post: RawPost) -> Post: @@ -190,10 +204,10 @@ class Platform(metaclass=PlatformABCMeta, base=True): return res async def dispatch_user_post( - self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo] + self, new_posts: list[RawPost], sub_unit: SubUnit ) -> list[tuple[PlatformTarget, list[Post]]]: res: list[tuple[PlatformTarget, list[Post]]] = [] - for user, cats, required_tags in users: + for user, cats, required_tags in sub_unit.user_sub_infos: user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags) user_post: list[Post] = [] for raw_post in user_raw_post: @@ -235,6 +249,12 @@ class MessageProcess(Platform, abstract=True): @abstractmethod async def get_sub_list(self, target: Target) -> list[RawPost]: "Get post list of the given target" + raise NotImplementedError() + + @abstractmethod + async def batch_get_sub_list(self, targets: list[Target]) -> list[list[RawPost]]: + "Get post list of the given targets" + raise NotImplementedError() @abstractmethod def get_date(self, post: RawPost) -> int | None: @@ -298,9 +318,12 @@ class NewMessage(MessageProcess, abstract=True): self.set_stored_data(target, store) return res - async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]: - post_list = await self.get_sub_list(target) - new_posts = await self.filter_common_with_diff(target, post_list) + async def _handle_new_post( + self, + post_list: list[RawPost], + sub_unit: SubUnit, + ) -> list[tuple[PlatformTarget, list[Post]]]: + new_posts = await self.filter_common_with_diff(sub_unit.sub_target, post_list) if not new_posts: return [] else: @@ -308,14 +331,27 @@ class NewMessage(MessageProcess, abstract=True): logger.info( "fetch new post from {} {}: {}".format( self.platform_name, - target if self.has_target else "-", + sub_unit.sub_target if self.has_target else "-", self.get_id(post), ) ) - res = await self.dispatch_user_post(target, new_posts, users) + res = await self.dispatch_user_post(new_posts, sub_unit) self.parse_cache = {} return res + async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]: + post_list = await self.get_sub_list(sub_unit.sub_target) + return await self._handle_new_post(post_list, sub_unit) + + async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]: + if not self.has_target: + raise RuntimeError("Target without target should not use batch api") # pragma: no cover + posts_set = await self.batch_get_sub_list([x[0] for x in sub_units]) + res = [] + for sub_unit, posts in zip(sub_units, posts_set): + res.extend(await self._handle_new_post(posts, sub_unit)) + return res + class StatusChange(Platform, abstract=True): "Watch a status, and fire a post when status changes" @@ -327,6 +363,10 @@ class StatusChange(Platform, abstract=True): async def get_status(self, target: Target) -> Any: ... + @abstractmethod + async def batch_get_status(self, targets: list[Target]) -> list[Any]: + ... + @abstractmethod def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]: ... @@ -335,34 +375,51 @@ class StatusChange(Platform, abstract=True): async def parse(self, raw_post: RawPost) -> Post: ... - async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]: - try: - new_status = await self.get_status(target) - except self.FetchError as err: - logger.warning(f"fetching {self.name}-{target} error: {err}") - raise + async def _handle_status_change( + self, new_status: Any, sub_unit: SubUnit + ) -> list[tuple[PlatformTarget, list[Post]]]: res = [] - if old_status := self.get_stored_data(target): - diff = self.compare_status(target, old_status, new_status) + if old_status := self.get_stored_data(sub_unit.sub_target): + diff = self.compare_status(sub_unit.sub_target, old_status, new_status) if diff: logger.info( "status changes {} {}: {} -> {}".format( self.platform_name, - target if self.has_target else "-", + sub_unit.sub_target if self.has_target else "-", old_status, new_status, ) ) - res = await self.dispatch_user_post(target, diff, users) - self.set_stored_data(target, new_status) + res = await self.dispatch_user_post(diff, sub_unit) + self.set_stored_data(sub_unit.sub_target, new_status) + return res + + async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]: + try: + new_status = await self.get_status(sub_unit.sub_target) + except self.FetchError as err: + logger.warning(f"fetching {self.name}-{sub_unit.sub_target} error: {err}") + raise + return await self._handle_status_change(new_status, sub_unit) + + async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]: + if not self.has_target: + raise RuntimeError("Target without target should not use batch api") # pragma: no cover + new_statuses = await self.batch_get_status([x[0] for x in sub_units]) + res = [] + for sub_unit, new_status in zip(sub_units, new_statuses): + res.extend(await self._handle_status_change(new_status, sub_unit)) return res -class SimplePost(MessageProcess, abstract=True): +class SimplePost(NewMessage, abstract=True): "Fetch a list of messages, dispatch it to different users" - async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]: - new_posts = await self.get_sub_list(target) + async def _handle_new_post( + self, + new_posts: list[RawPost], + sub_unit: SubUnit, + ) -> list[tuple[PlatformTarget, list[Post]]]: if not new_posts: return [] else: @@ -370,11 +427,11 @@ class SimplePost(MessageProcess, abstract=True): logger.info( "fetch new post from {} {}: {}".format( self.platform_name, - target if self.has_target else "-", + sub_unit.sub_target if self.has_target else "-", self.get_id(post), ) ) - res = await self.dispatch_user_post(target, new_posts, users) + res = await self.dispatch_user_post(new_posts, sub_unit) self.parse_cache = {} return res @@ -422,10 +479,10 @@ def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]: async def get_target_name(cls, client: AsyncClient, target: Target): return await platform_list[0].get_target_name(client, target) - async def fetch_new_post(self: "NoTargetGroup", target: Target, users: list[UserSubInfo]): + async def fetch_new_post(self: "NoTargetGroup", sub_unit: SubUnit): res = defaultdict(list) for platform in self.platform_obj_list: - platform_res = await platform.fetch_new_post(target=target, users=users) + platform_res = await platform.fetch_new_post(sub_unit) for user, posts in platform_res: res[user].extend(posts) return [[key, val] for key, val in res.items()] diff --git a/nonebot_bison/scheduler/manager.py b/nonebot_bison/scheduler/manager.py index 271f1b5..b2ab3cc 100644 --- a/nonebot_bison/scheduler/manager.py +++ b/nonebot_bison/scheduler/manager.py @@ -29,7 +29,9 @@ async def init_scheduler(): for scheduler_config, target_list in _schedule_class_dict.items(): schedulable_args = [] for target in target_list: - schedulable_args.append((target.platform_name, T_Target(target.target))) + schedulable_args.append( + (target.platform_name, T_Target(target.target), platform_manager[target.platform_name].use_batch) + ) platform_name_list = _schedule_class_platform_dict[scheduler_config] scheduler_dict[scheduler_config] = Scheduler(scheduler_config, schedulable_args, platform_name_list) config.register_add_target_hook(handle_insert_new_target) diff --git a/nonebot_bison/scheduler/scheduler.py b/nonebot_bison/scheduler/scheduler.py index bff3791..b1fc530 100644 --- a/nonebot_bison/scheduler/scheduler.py +++ b/nonebot_bison/scheduler/scheduler.py @@ -1,12 +1,13 @@ from dataclasses import dataclass +from collections import defaultdict from nonebot.log import logger from nonebot_plugin_apscheduler import scheduler from nonebot_plugin_saa.utils.exceptions import NoBotFound -from ..types import Target from ..config import config from ..send import send_msgs +from ..types import Target, SubUnit from ..platform import platform_manager from ..utils import ProcessContext, SchedulerConfig @@ -16,15 +17,18 @@ class Schedulable: platform_name: str target: Target current_weight: int + use_batch: bool = False class Scheduler: - schedulable_list: list[Schedulable] + schedulable_list: list[Schedulable] # for load weigth from db + batch_api_target_cache: dict[str, dict[Target, list[Target]]] # platform_name -> (target -> [target]) + batch_platform_name_targets_cache: dict[str, list[Target]] def __init__( self, scheduler_config: type[SchedulerConfig], - schedulables: list[tuple[str, Target]], + schedulables: list[tuple[str, Target, bool]], # [(platform_name, target, use_batch)] platform_name_list: list[str], ): self.name = scheduler_config.name @@ -33,9 +37,17 @@ class Scheduler: raise RuntimeError(f"{self.name} not found") self.scheduler_config = scheduler_config self.scheduler_config_obj = self.scheduler_config() + self.schedulable_list = [] - for platform_name, target in schedulables: - self.schedulable_list.append(Schedulable(platform_name=platform_name, target=target, current_weight=0)) + self.batch_platform_name_targets_cache: dict[str, list[Target]] = defaultdict(list) + for platform_name, target, use_batch in schedulables: + if use_batch: + self.batch_platform_name_targets_cache[platform_name].append(target) + self.schedulable_list.append( + Schedulable(platform_name=platform_name, target=target, current_weight=0, use_batch=use_batch) + ) + self._refresh_batch_api_target_cache() + self.platform_name_list = platform_name_list self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值 logger.info( @@ -48,6 +60,12 @@ class Scheduler: **self.scheduler_config.schedule_setting, ) + def _refresh_batch_api_target_cache(self): + self.batch_api_target_cache = defaultdict(dict) + for platform_name, targets in self.batch_platform_name_targets_cache.items(): + for target in targets: + self.batch_api_target_cache[platform_name][target] = targets + async def get_next_schedulable(self) -> Schedulable | None: if not self.schedulable_list: return None @@ -69,14 +87,24 @@ class Scheduler: if not (schedulable := await self.get_next_schedulable()): return logger.trace(f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}") - send_userinfo_list = await config.get_platform_target_subscribers(schedulable.platform_name, schedulable.target) client = await self.scheduler_config_obj.get_client(schedulable.target) context.register_to_client(client) try: platform_obj = platform_manager[schedulable.platform_name](context, client) - to_send = await platform_obj.do_fetch_new_post(schedulable.target, send_userinfo_list) + if schedulable.use_batch: + batch_targets = self.batch_api_target_cache[schedulable.platform_name][schedulable.target] + sub_units = [] + for batch_target in batch_targets: + userinfo = await config.get_platform_target_subscribers(schedulable.platform_name, batch_target) + sub_units.append(SubUnit(batch_target, userinfo)) + to_send = await platform_obj.do_batch_fetch_new_post(sub_units) + else: + send_userinfo_list = await config.get_platform_target_subscribers( + schedulable.platform_name, schedulable.target + ) + to_send = await platform_obj.do_fetch_new_post(SubUnit(schedulable.target, send_userinfo_list)) except Exception as err: records = context.gen_req_records() for record in records: @@ -101,9 +129,18 @@ class Scheduler: def insert_new_schedulable(self, platform_name: str, target: Target): self.pre_weight_val += 1000 self.schedulable_list.append(Schedulable(platform_name, target, 1000)) + + if platform_manager[platform_name].use_batch: + self.batch_platform_name_targets_cache[platform_name].append(target) + self._refresh_batch_api_target_cache() + logger.info(f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})") def delete_schedulable(self, platform_name, target: Target): + if platform_manager[platform_name].use_batch: + self.batch_platform_name_targets_cache[platform_name].remove(target) + self._refresh_batch_api_target_cache() + if not self.schedulable_list: return to_find_idx = None @@ -114,4 +151,3 @@ class Scheduler: if to_find_idx is not None: deleted_schdulable = self.schedulable_list.pop(to_find_idx) self.pre_weight_val -= deleted_schdulable.current_weight - return diff --git a/nonebot_bison/types.py b/nonebot_bison/types.py index 487f5d0..0d08bfd 100644 --- a/nonebot_bison/types.py +++ b/nonebot_bison/types.py @@ -53,3 +53,8 @@ class ApiError(Exception): def __init__(self, url: URL) -> None: msg = f"api {url} error" super().__init__(msg) + + +class SubUnit(NamedTuple): + sub_target: Target + user_sub_infos: list[UserSubInfo] diff --git a/tests/platforms/test_arknights.py b/tests/platforms/test_arknights.py index 8d4592a..8175a47 100644 --- a/tests/platforms/test_arknights.py +++ b/tests/platforms/test_arknights.py @@ -49,6 +49,8 @@ async def test_fetch_new( monster_siren_list_0, monster_siren_list_1, ): + from nonebot_bison.types import Target, SubUnit + ak_list_router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS") detail_router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletin/5716") version_router = respx.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version") @@ -63,14 +65,14 @@ async def test_fetch_new( preannouncement_router.mock(return_value=Response(200, json=get_json("arknights-pre-0.json"))) monster_siren_router.mock(return_value=Response(200, json=monster_siren_list_0)) terra_list.mock(return_value=Response(200, json=get_json("terra-hist-0.json"))) - target = "" - res = await arknights.fetch_new_post(target, [dummy_user_subinfo]) + target = Target("") + res = await arknights.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert ak_list_router.called assert len(res) == 0 assert not detail_router.called mock_data = arknights_list_0 ak_list_router.mock(return_value=Response(200, json=mock_data)) - res3 = await arknights.fetch_new_post(target, [dummy_user_subinfo]) + res3 = await arknights.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert len(res3[0][1]) == 1 assert detail_router.called post = res3[0][1][0] @@ -82,7 +84,7 @@ async def test_fetch_new( # assert(post.pics == ['https://ak-fs.hypergryph.com/announce/images/20210623/e6f49aeb9547a2278678368a43b95b07.jpg']) await post.generate_messages() terra_list.mock(return_value=Response(200, json=get_json("terra-hist-1.json"))) - res = await arknights.fetch_new_post(target, [dummy_user_subinfo]) + res = await arknights.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert len(res) == 1 post = res[0][1][0] assert post.target_type == "terra-historicus" @@ -101,6 +103,8 @@ async def test_send_with_render( monster_siren_list_0, monster_siren_list_1, ): + from nonebot_bison.types import Target, SubUnit + ak_list_router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS") detail_router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletin/8397") version_router = respx.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version") @@ -115,14 +119,14 @@ async def test_send_with_render( preannouncement_router.mock(return_value=Response(200, json=get_json("arknights-pre-0.json"))) monster_siren_router.mock(return_value=Response(200, json=monster_siren_list_0)) terra_list.mock(return_value=Response(200, json=get_json("terra-hist-0.json"))) - target = "" - res = await arknights.fetch_new_post(target, [dummy_user_subinfo]) + target = Target("") + res = await arknights.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert ak_list_router.called assert len(res) == 0 assert not detail_router.called mock_data = arknights_list_1 ak_list_router.mock(return_value=Response(200, json=mock_data)) - res3 = await arknights.fetch_new_post(target, [dummy_user_subinfo]) + res3 = await arknights.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert len(res3[0][1]) == 1 assert detail_router.called post = res3[0][1][0] diff --git a/tests/platforms/test_bilibili.py b/tests/platforms/test_bilibili.py index 453f775..0b10cc6 100644 --- a/tests/platforms/test_bilibili.py +++ b/tests/platforms/test_bilibili.py @@ -106,14 +106,16 @@ async def test_dynamic_forward(bilibili, bing_dy_list): @pytest.mark.asyncio @respx.mock async def test_fetch_new_without_dynamic(bilibili, dummy_user_subinfo, without_dynamic): + from nonebot_bison.types import Target, SubUnit + post_router = respx.get( "https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history?host_uid=161775300&offset=0&need_top=0" ) post_router.mock(return_value=Response(200, json=without_dynamic)) bilibili_main_page_router = respx.get("https://www.bilibili.com/") bilibili_main_page_router.mock(return_value=Response(200)) - target = "161775300" - res = await bilibili.fetch_new_post(target, [dummy_user_subinfo]) + target = Target("161775300") + res = await bilibili.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert post_router.called assert len(res) == 0 @@ -121,21 +123,23 @@ async def test_fetch_new_without_dynamic(bilibili, dummy_user_subinfo, without_d @pytest.mark.asyncio @respx.mock async def test_fetch_new(bilibili, dummy_user_subinfo): + from nonebot_bison.types import Target, SubUnit + post_router = respx.get( "https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history?host_uid=161775300&offset=0&need_top=0" ) post_router.mock(return_value=Response(200, json=get_json("bilibili_strange_post-0.json"))) bilibili_main_page_router = respx.get("https://www.bilibili.com/") bilibili_main_page_router.mock(return_value=Response(200)) - target = "161775300" - res = await bilibili.fetch_new_post(target, [dummy_user_subinfo]) + target = Target("161775300") + res = await bilibili.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert post_router.called assert len(res) == 0 mock_data = get_json("bilibili_strange_post.json") mock_data["data"]["cards"][0]["desc"]["timestamp"] = int(datetime.now().timestamp()) post_router.mock(return_value=Response(200, json=mock_data)) - res2 = await bilibili.fetch_new_post(target, [dummy_user_subinfo]) + res2 = await bilibili.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert len(res2[0][1]) == 1 post = res2[0][1][0] assert ( diff --git a/tests/platforms/test_bilibili_bangumi.py b/tests/platforms/test_bilibili_bangumi.py index 6506682..a2d37d7 100644 --- a/tests/platforms/test_bilibili_bangumi.py +++ b/tests/platforms/test_bilibili_bangumi.py @@ -35,7 +35,7 @@ async def test_parse_target(bili_bangumi: "BilibiliBangumi"): @pytest.mark.asyncio @respx.mock async def test_fetch_bilibili_bangumi_status(bili_bangumi: "BilibiliBangumi", dummy_user_subinfo): - from nonebot_bison.types import Target + from nonebot_bison.types import Target, SubUnit bili_bangumi_router = respx.get("https://api.bilibili.com/pgc/review/user?media_id=28235413") bili_bangumi_detail_router = respx.get("https://api.bilibili.com/pgc/view/web/season?season_id=39719") @@ -43,15 +43,15 @@ async def test_fetch_bilibili_bangumi_status(bili_bangumi: "BilibiliBangumi", du bilibili_main_page_router = respx.get("https://www.bilibili.com/") bilibili_main_page_router.mock(return_value=Response(200)) target = Target("28235413") - res = await bili_bangumi.fetch_new_post(target, [dummy_user_subinfo]) + res = await bili_bangumi.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert len(res) == 0 - res = await bili_bangumi.fetch_new_post(target, [dummy_user_subinfo]) + res = await bili_bangumi.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert len(res) == 0 bili_bangumi_router.mock(return_value=Response(200, json=get_json("bilibili-gangumi-hanhua1.json"))) bili_bangumi_detail_router.mock(return_value=Response(200, json=get_json("bilibili-gangumi-hanhua1-detail.json"))) - res2 = await bili_bangumi.fetch_new_post(target, [dummy_user_subinfo]) + res2 = await bili_bangumi.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) post = res2[0][1][0] assert post.target_type == "Bilibili剧集" diff --git a/tests/platforms/test_bilibili_live.py b/tests/platforms/test_bilibili_live.py index 744f58b..1889694 100644 --- a/tests/platforms/test_bilibili_live.py +++ b/tests/platforms/test_bilibili_live.py @@ -29,16 +29,18 @@ def dummy_only_open_user_subinfo(app: App): @pytest.mark.asyncio @respx.mock async def test_fetch_bililive_no_room(bili_live, dummy_only_open_user_subinfo): + from nonebot_bison.types import Target, SubUnit + mock_bili_live_status = get_json("bili_live_status.json") - mock_bili_live_status["data"] = [] + mock_bili_live_status["data"] = {} bili_live_router = respx.get("https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids?uids[]=13164144") bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) bilibili_main_page_router = respx.get("https://www.bilibili.com/") bilibili_main_page_router.mock(return_value=Response(200)) - target = "13164144" - res = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo]) + target = Target("13164144") + res = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))]) assert bili_live_router.call_count == 1 assert len(res) == 0 @@ -46,23 +48,25 @@ async def test_fetch_bililive_no_room(bili_live, dummy_only_open_user_subinfo): @pytest.mark.asyncio @respx.mock async def test_fetch_first_live(bili_live, dummy_only_open_user_subinfo): + from nonebot_bison.types import Target, SubUnit + mock_bili_live_status = get_json("bili_live_status.json") empty_bili_live_status = deepcopy(mock_bili_live_status) - empty_bili_live_status["data"] = [] + empty_bili_live_status["data"] = {} bili_live_router = respx.get("https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids?uids[]=13164144") bili_live_router.mock(return_value=Response(200, json=empty_bili_live_status)) bilibili_main_page_router = respx.get("https://www.bilibili.com/") bilibili_main_page_router.mock(return_value=Response(200)) - target = "13164144" - res = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo]) + target = Target("13164144") + res = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))]) assert bili_live_router.call_count == 1 assert len(res) == 0 mock_bili_live_status["data"][target]["live_status"] = 1 bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res2 = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo]) + res2 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))]) assert bili_live_router.call_count == 2 assert len(res2) == 1 post = res2[0][1][0] @@ -77,6 +81,8 @@ async def test_fetch_first_live(bili_live, dummy_only_open_user_subinfo): @pytest.mark.asyncio @respx.mock async def test_fetch_bililive_only_live_open(bili_live, dummy_only_open_user_subinfo): + from nonebot_bison.types import Target, SubUnit + mock_bili_live_status = get_json("bili_live_status.json") bili_live_router = respx.get("https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids?uids[]=13164144") @@ -85,14 +91,14 @@ async def test_fetch_bililive_only_live_open(bili_live, dummy_only_open_user_sub bilibili_main_page_router = respx.get("https://www.bilibili.com/") bilibili_main_page_router.mock(return_value=Response(200)) - target = "13164144" - res = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo]) + target = Target("13164144") + res = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))]) assert bili_live_router.call_count == 1 assert len(res[0][1]) == 0 # 直播状态更新-上播 mock_bili_live_status["data"][target]["live_status"] = 1 bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res2 = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo]) + res2 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))]) post = res2[0][1][0] assert post.target_type == "Bilibili直播" assert post.text == "[开播] 【Zc】从0挑战到15肉鸽!目前10难度" @@ -103,13 +109,13 @@ async def test_fetch_bililive_only_live_open(bili_live, dummy_only_open_user_sub # 标题变更 mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽!目前11难度" bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res3 = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo]) + res3 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))]) assert bili_live_router.call_count == 3 assert len(res3[0][1]) == 0 # 直播状态更新-下播 mock_bili_live_status["data"][target]["live_status"] = 0 bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res4 = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo]) + res4 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))]) assert bili_live_router.call_count == 4 assert len(res4[0][1]) == 0 @@ -127,8 +133,10 @@ def dummy_only_title_user_subinfo(app: App): @pytest.mark.asyncio() @respx.mock async def test_fetch_bililive_only_title_change(bili_live, dummy_only_title_user_subinfo): + from nonebot_bison.types import Target, SubUnit + mock_bili_live_status = get_json("bili_live_status.json") - target = "13164144" + target = Target("13164144") bili_live_router = respx.get("https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids?uids[]=13164144") bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) @@ -136,25 +144,25 @@ async def test_fetch_bililive_only_title_change(bili_live, dummy_only_title_user bilibili_main_page_router = respx.get("https://www.bilibili.com/") bilibili_main_page_router.mock(return_value=Response(200)) - res = await bili_live.fetch_new_post(target, [dummy_only_title_user_subinfo]) + res = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_title_user_subinfo]))]) assert bili_live_router.call_count == 1 assert len(res) == 0 # 未开播前标题变更 mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽!目前11难度" bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res0 = await bili_live.fetch_new_post(target, [dummy_only_title_user_subinfo]) + res0 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_title_user_subinfo]))]) assert bili_live_router.call_count == 2 assert len(res0) == 0 # 直播状态更新-上播 mock_bili_live_status["data"][target]["live_status"] = 1 bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res2 = await bili_live.fetch_new_post(target, [dummy_only_title_user_subinfo]) + res2 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_title_user_subinfo]))]) assert bili_live_router.call_count == 3 assert len(res2[0][1]) == 0 # 标题变更 mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽!目前12难度" bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res3 = await bili_live.fetch_new_post(target, [dummy_only_title_user_subinfo]) + res3 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_title_user_subinfo]))]) post = res3[0][1][0] assert post.target_type == "Bilibili直播" assert post.text == "[标题更新] 【Zc】从0挑战到15肉鸽!目前12难度" @@ -165,7 +173,7 @@ async def test_fetch_bililive_only_title_change(bili_live, dummy_only_title_user # 直播状态更新-下播 mock_bili_live_status["data"][target]["live_status"] = 0 bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res4 = await bili_live.fetch_new_post(target, [dummy_only_title_user_subinfo]) + res4 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_title_user_subinfo]))]) assert bili_live_router.call_count == 5 assert len(res4[0][1]) == 0 @@ -183,8 +191,10 @@ def dummy_only_close_user_subinfo(app: App): @pytest.mark.asyncio @respx.mock async def test_fetch_bililive_only_close(bili_live, dummy_only_close_user_subinfo): + from nonebot_bison.types import Target, SubUnit + mock_bili_live_status = get_json("bili_live_status.json") - target = "13164144" + target = Target("13164144") bili_live_router = respx.get("https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids?uids[]=13164144") bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) @@ -192,31 +202,31 @@ async def test_fetch_bililive_only_close(bili_live, dummy_only_close_user_subinf bilibili_main_page_router = respx.get("https://www.bilibili.com/") bilibili_main_page_router.mock(return_value=Response(200)) - res = await bili_live.fetch_new_post(target, [dummy_only_close_user_subinfo]) + res = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_close_user_subinfo]))]) assert bili_live_router.call_count == 1 assert len(res) == 0 # 未开播前标题变更 mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽!目前11难度" bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res0 = await bili_live.fetch_new_post(target, [dummy_only_close_user_subinfo]) + res0 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_close_user_subinfo]))]) assert bili_live_router.call_count == 2 assert len(res0) == 0 # 直播状态更新-上播 mock_bili_live_status["data"][target]["live_status"] = 1 bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res2 = await bili_live.fetch_new_post(target, [dummy_only_close_user_subinfo]) + res2 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_close_user_subinfo]))]) assert bili_live_router.call_count == 3 assert len(res2[0][1]) == 0 # 标题变更 mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽!目前12难度" bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res3 = await bili_live.fetch_new_post(target, [dummy_only_close_user_subinfo]) + res3 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_close_user_subinfo]))]) assert bili_live_router.call_count == 4 assert len(res3[0][1]) == 0 # 直播状态更新-下播 mock_bili_live_status["data"][target]["live_status"] = 0 bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res4 = await bili_live.fetch_new_post(target, [dummy_only_close_user_subinfo]) + res4 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_close_user_subinfo]))]) assert bili_live_router.call_count == 5 post = res4[0][1][0] assert post.target_type == "Bilibili直播" @@ -240,8 +250,10 @@ def dummy_bililive_user_subinfo(app: App): @pytest.mark.asyncio @respx.mock async def test_fetch_bililive_combo(bili_live, dummy_bililive_user_subinfo): + from nonebot_bison.types import Target, SubUnit + mock_bili_live_status = get_json("bili_live_status.json") - target = "13164144" + target = Target("13164144") bili_live_router = respx.get("https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids?uids[]=13164144") bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) @@ -249,19 +261,19 @@ async def test_fetch_bililive_combo(bili_live, dummy_bililive_user_subinfo): bilibili_main_page_router = respx.get("https://www.bilibili.com/") bilibili_main_page_router.mock(return_value=Response(200)) - res = await bili_live.fetch_new_post(target, [dummy_bililive_user_subinfo]) + res = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_bililive_user_subinfo]))]) assert bili_live_router.call_count == 1 assert len(res) == 0 # 未开播前标题变更 mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽!目前11难度" bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res0 = await bili_live.fetch_new_post(target, [dummy_bililive_user_subinfo]) + res0 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_bililive_user_subinfo]))]) assert bili_live_router.call_count == 2 assert len(res0) == 0 # 直播状态更新-上播 mock_bili_live_status["data"][target]["live_status"] = 1 bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res2 = await bili_live.fetch_new_post(target, [dummy_bililive_user_subinfo]) + res2 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_bililive_user_subinfo]))]) post2 = res2[0][1][0] assert post2.target_type == "Bilibili直播" assert post2.text == "[开播] 【Zc】从0挑战到15肉鸽!目前11难度" @@ -272,7 +284,7 @@ async def test_fetch_bililive_combo(bili_live, dummy_bililive_user_subinfo): # 标题变更 mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽!目前12难度" bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res3 = await bili_live.fetch_new_post(target, [dummy_bililive_user_subinfo]) + res3 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_bililive_user_subinfo]))]) post3 = res3[0][1][0] assert post3.target_type == "Bilibili直播" assert post3.text == "[标题更新] 【Zc】从0挑战到15肉鸽!目前12难度" @@ -283,7 +295,7 @@ async def test_fetch_bililive_combo(bili_live, dummy_bililive_user_subinfo): # 直播状态更新-下播 mock_bili_live_status["data"][target]["live_status"] = 0 bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) - res4 = await bili_live.fetch_new_post(target, [dummy_bililive_user_subinfo]) + res4 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_bililive_user_subinfo]))]) post4 = res4[0][1][0] assert post4.target_type == "Bilibili直播" assert post4.text == "[下播] 【Zc】从0挑战到15肉鸽!目前12难度" diff --git a/tests/platforms/test_ff14.py b/tests/platforms/test_ff14.py index f165d26..9631c70 100644 --- a/tests/platforms/test_ff14.py +++ b/tests/platforms/test_ff14.py @@ -27,16 +27,18 @@ def ff14_newdata_json_1(): @pytest.mark.asyncio @respx.mock async def test_fetch_new(ff14, dummy_user_subinfo, ff14_newdata_json_0, ff14_newdata_json_1): + from nonebot_bison.types import Target, SubUnit + newdata = respx.get( "https://cqnews.web.sdo.com/api/news/newsList?gameCode=ff&CategoryCode=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5" ) newdata.mock(return_value=Response(200, json=ff14_newdata_json_0)) - target = "" - res = await ff14.fetch_new_post(target, [dummy_user_subinfo]) + target = Target("") + res = await ff14.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert newdata.called assert len(res) == 0 newdata.mock(return_value=Response(200, json=ff14_newdata_json_1)) - res = await ff14.fetch_new_post(target, [dummy_user_subinfo]) + res = await ff14.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert newdata.called post = res[0][1][0] assert post.target_type == "ff14" diff --git a/tests/platforms/test_ncm_artist.py b/tests/platforms/test_ncm_artist.py index b50f156..d982229 100644 --- a/tests/platforms/test_ncm_artist.py +++ b/tests/platforms/test_ncm_artist.py @@ -41,14 +41,16 @@ def ncm_artist_1(ncm_artist_raw: dict): @pytest.mark.asyncio @respx.mock async def test_fetch_new(ncm_artist, ncm_artist_0, ncm_artist_1, dummy_user_subinfo): + from nonebot_bison.types import Target, SubUnit + ncm_router = respx.get("https://music.163.com/api/artist/albums/32540734") ncm_router.mock(return_value=Response(200, json=ncm_artist_0)) - target = "32540734" - res = await ncm_artist.fetch_new_post(target, [dummy_user_subinfo]) + target = Target("32540734") + res = await ncm_artist.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert ncm_router.called assert len(res) == 0 ncm_router.mock(return_value=Response(200, json=ncm_artist_1)) - res2 = await ncm_artist.fetch_new_post(target, [dummy_user_subinfo]) + res2 = await ncm_artist.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) post = res2[0][1][0] assert post.target_type == "ncm-artist" assert post.text == "新专辑发布:Y1K" diff --git a/tests/platforms/test_ncm_radio.py b/tests/platforms/test_ncm_radio.py index 7de411e..7547540 100644 --- a/tests/platforms/test_ncm_radio.py +++ b/tests/platforms/test_ncm_radio.py @@ -41,14 +41,16 @@ def ncm_radio_1(ncm_radio_raw: dict): @pytest.mark.asyncio @respx.mock async def test_fetch_new(ncm_radio, ncm_radio_0, ncm_radio_1, dummy_user_subinfo): + from nonebot_bison.types import Target, SubUnit + ncm_router = respx.post("http://music.163.com/api/dj/program/byradio") ncm_router.mock(return_value=Response(200, json=ncm_radio_0)) - target = "793745436" - res = await ncm_radio.fetch_new_post(target, [dummy_user_subinfo]) + target = Target("793745436") + res = await ncm_radio.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert ncm_router.called assert len(res) == 0 ncm_router.mock(return_value=Response(200, json=ncm_radio_1)) - res2 = await ncm_radio.fetch_new_post(target, [dummy_user_subinfo]) + res2 = await ncm_radio.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) post = res2[0][1][0] assert post.target_type == "ncm-radio" assert post.text == "网易云电台更新:「松烟行动」灰齐山麓" diff --git a/tests/platforms/test_platform.py b/tests/platforms/test_platform.py index c8a7462..f2dce29 100644 --- a/tests/platforms/test_platform.py +++ b/tests/platforms/test_platform.py @@ -325,16 +325,14 @@ def mock_status_change(app: App): @pytest.mark.asyncio async def test_new_message_target_without_cats_tags(mock_platform_without_cats_tags, user_info_factory): from nonebot_bison.utils import ProcessContext + from nonebot_bison.types import Target, SubUnit res1 = await mock_platform_without_cats_tags(ProcessContext(), AsyncClient()).fetch_new_post( - "dummy", [user_info_factory([1, 2], [])] + SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) ) assert len(res1) == 0 res2 = await mock_platform_without_cats_tags(ProcessContext(), AsyncClient()).fetch_new_post( - "dummy", - [ - user_info_factory([], []), - ], + SubUnit(Target("dummy"), [user_info_factory([], [])]), ) assert len(res2) == 1 posts_1 = res2[0][1] @@ -348,16 +346,21 @@ async def test_new_message_target_without_cats_tags(mock_platform_without_cats_t @pytest.mark.asyncio async def test_new_message_target(mock_platform, user_info_factory): from nonebot_bison.utils import ProcessContext + from nonebot_bison.types import Target, SubUnit - res1 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post("dummy", [user_info_factory([1, 2], [])]) + res1 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post( + SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) + ) assert len(res1) == 0 res2 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post( - "dummy", - [ - user_info_factory([1, 2], []), - user_info_factory([1], []), - user_info_factory([1, 2], ["tag1"]), - ], + SubUnit( + Target("dummy"), + [ + user_info_factory([1, 2], []), + user_info_factory([1], []), + user_info_factory([1, 2], ["tag1"]), + ], + ), ) assert len(res2) == 3 posts_1 = res2[0][1] @@ -378,18 +381,21 @@ async def test_new_message_target(mock_platform, user_info_factory): @pytest.mark.asyncio async def test_new_message_no_target(mock_platform_no_target, user_info_factory): from nonebot_bison.utils import ProcessContext + from nonebot_bison.types import Target, SubUnit res1 = await mock_platform_no_target(ProcessContext(), AsyncClient()).fetch_new_post( - "dummy", [user_info_factory([1, 2], [])] + SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) ) assert len(res1) == 0 res2 = await mock_platform_no_target(ProcessContext(), AsyncClient()).fetch_new_post( - "dummy", - [ - user_info_factory([1, 2], []), - user_info_factory([1], []), - user_info_factory([1, 2], ["tag1"]), - ], + SubUnit( + Target("dummy"), + [ + user_info_factory([1, 2], []), + user_info_factory([1], []), + user_info_factory([1, 2], ["tag1"]), + ], + ), ) assert len(res2) == 3 posts_1 = res2[0][1] @@ -406,7 +412,7 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory) assert "p2" in id_set_2 assert "p2" in id_set_3 res3 = await mock_platform_no_target(ProcessContext(), AsyncClient()).fetch_new_post( - "dummy", [user_info_factory([1, 2], [])] + SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) ) assert len(res3) == 0 @@ -414,31 +420,34 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory) @pytest.mark.asyncio async def test_status_change(mock_status_change, user_info_factory): from nonebot_bison.utils import ProcessContext + from nonebot_bison.types import Target, SubUnit res1 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( - "dummy", [user_info_factory([1, 2], [])] + SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) ) assert len(res1) == 0 res2 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( - "dummy", [user_info_factory([1, 2], [])] + SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) ) assert len(res2) == 1 posts = res2[0][1] assert len(posts) == 1 assert posts[0].text == "on" res3 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( - "dummy", - [ - user_info_factory([1, 2], []), - user_info_factory([1], []), - ], + SubUnit( + Target("dummy"), + [ + user_info_factory([1, 2], []), + user_info_factory([1], []), + ], + ), ) assert len(res3) == 2 assert len(res3[0][1]) == 1 assert res3[0][1][0].text == "off" assert len(res3[1][1]) == 0 res4 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( - "dummy", [user_info_factory([1, 2], [])] + SubUnit(Target("dummy"), [user_info_factory([1, 2], [])]) ) assert len(res4) == 0 @@ -450,7 +459,7 @@ async def test_group( mock_platform_no_target_2, user_info_factory, ): - from nonebot_bison.types import Target + from nonebot_bison.types import Target, SubUnit from nonebot_bison.utils import ProcessContext, http_client from nonebot_bison.platform.platform import make_no_target_group @@ -458,14 +467,176 @@ async def test_group( group_platform_class = make_no_target_group([mock_platform_no_target, mock_platform_no_target_2]) group_platform = group_platform_class(ProcessContext(), http_client()) - res1 = await group_platform.fetch_new_post(dummy, [user_info_factory([1, 4], [])]) + res1 = await group_platform.fetch_new_post(SubUnit(dummy, [user_info_factory([1, 4], [])])) assert len(res1) == 0 - res2 = await group_platform.fetch_new_post(dummy, [user_info_factory([1, 4], [])]) + res2 = await group_platform.fetch_new_post(SubUnit(dummy, [user_info_factory([1, 4], [])])) assert len(res2) == 1 posts = res2[0][1] assert len(posts) == 2 id_set_2 = {x.text for x in posts} assert "p2" in id_set_2 assert "p6" in id_set_2 - res3 = await group_platform.fetch_new_post(dummy, [user_info_factory([1, 4], [])]) + res3 = await group_platform.fetch_new_post(SubUnit(dummy, [user_info_factory([1, 4], [])])) assert len(res3) == 0 + + +async def test_batch_fetch_new_message(app: App): + from nonebot_plugin_saa import TargetQQGroup + + from nonebot_bison.post import Post + from nonebot_bison.platform.platform import NewMessage + from nonebot_bison.utils.context import ProcessContext + from nonebot_bison.types import Target, RawPost, SubUnit, UserSubInfo + + class BatchNewMessage(NewMessage): + platform_name = "mock_platform" + name = "Mock Platform" + enabled = True + is_common = True + schedule_interval = 10 + enable_tag = False + categories = {} + has_target = True + + sub_index = 0 + + @classmethod + async def get_target_name(cls, client, _: "Target"): + return "MockPlatform" + + def get_id(self, post: "RawPost") -> Any: + return post["id"] + + def get_date(self, raw_post: "RawPost") -> float: + return raw_post["date"] + + async def parse(self, raw_post: "RawPost") -> "Post": + return Post( + "mock_platform", + raw_post["text"], + "http://t.tt/" + str(self.get_id(raw_post)), + target_name="Mock", + ) + + @classmethod + async def batch_get_sub_list(cls, targets: list[Target]) -> list[list[RawPost]]: + assert len(targets) > 1 + if cls.sub_index == 0: + cls.sub_index += 1 + res = [ + [raw_post_list_2[0]], + [raw_post_list_2[1]], + ] + else: + res = [ + [raw_post_list_2[0], raw_post_list_2[2]], + [raw_post_list_2[1], raw_post_list_2[3]], + ] + res += [[]] * (len(targets) - 2) + return res + + user1 = UserSubInfo(TargetQQGroup(group_id=123), [1, 2, 3], []) + user2 = UserSubInfo(TargetQQGroup(group_id=234), [1, 2, 3], []) + + platform_obj = BatchNewMessage(ProcessContext(), None) # type:ignore + + res1 = await platform_obj.batch_fetch_new_post( + [ + SubUnit(Target("target1"), [user1]), + SubUnit(Target("target2"), [user1, user2]), + SubUnit(Target("target3"), [user2]), + ] + ) + assert len(res1) == 0 + + res2 = await platform_obj.batch_fetch_new_post( + [ + SubUnit(Target("target1"), [user1]), + SubUnit(Target("target2"), [user1, user2]), + SubUnit(Target("target3"), [user2]), + ] + ) + assert len(res2) == 3 + send_set = set() + for platform_target, posts in res2: + for post in posts: + send_set.add((platform_target, post.text)) + assert (TargetQQGroup(group_id=123), "p3") in send_set + assert (TargetQQGroup(group_id=123), "p4") in send_set + assert (TargetQQGroup(group_id=234), "p4") in send_set + + +async def test_batch_fetch_compare_status(app: App): + from nonebot_plugin_saa import TargetQQGroup + + from nonebot_bison.post import Post + from nonebot_bison.utils.context import ProcessContext + from nonebot_bison.platform.platform import StatusChange + from nonebot_bison.types import Target, RawPost, SubUnit, Category, UserSubInfo + + class BatchStatusChange(StatusChange): + platform_name = "mock_platform" + name = "Mock Platform" + enabled = True + is_common = True + enable_tag = False + schedule_type = "interval" + schedule_kw = {"seconds": 10} + has_target = False + categories = { + Category(1): "转发", + Category(2): "视频", + } + + sub_index = 0 + + @classmethod + async def batch_get_status(cls, targets: "list[Target]"): + assert len(targets) > 0 + res = [{"s": cls.sub_index == 1} for _ in targets] + res[0]["s"] = not res[0]["s"] + if cls.sub_index == 0: + cls.sub_index += 1 + return res + + def compare_status(self, target, old_status, new_status) -> list["RawPost"]: + if old_status["s"] is False and new_status["s"] is True: + return [{"text": "on", "cat": 1}] + elif old_status["s"] is True and new_status["s"] is False: + return [{"text": "off", "cat": 2}] + return [] + + async def parse(self, raw_post) -> "Post": + return Post("mock_status", raw_post["text"], "") + + def get_category(self, raw_post): + return raw_post["cat"] + + batch_status_change = BatchStatusChange(ProcessContext(), None) # type: ignore + + user1 = UserSubInfo(TargetQQGroup(group_id=123), [1, 2, 3], []) + user2 = UserSubInfo(TargetQQGroup(group_id=234), [1, 2, 3], []) + + res1 = await batch_status_change.batch_fetch_new_post( + [ + SubUnit(Target("target1"), [user1]), + SubUnit(Target("target2"), [user1, user2]), + ] + ) + assert len(res1) == 0 + + res2 = await batch_status_change.batch_fetch_new_post( + [ + SubUnit(Target("target1"), [user1]), + SubUnit(Target("target2"), [user1, user2]), + ] + ) + + send_set = set() + for platform_target, posts in res2: + for post in posts: + send_set.add((platform_target, post.text)) + assert len(send_set) == 3 + assert (TargetQQGroup(group_id=123), "off") in send_set + assert (TargetQQGroup(group_id=123), "on") in send_set + assert (TargetQQGroup(group_id=234), "on") in send_set diff --git a/tests/platforms/test_rss.py b/tests/platforms/test_rss.py index 88a78cf..bb80a2f 100644 --- a/tests/platforms/test_rss.py +++ b/tests/platforms/test_rss.py @@ -68,15 +68,17 @@ async def test_fetch_new_1( user_info_factory, update_time_feed_1, ): + from nonebot_bison.types import Target, SubUnit + ## 标题重复的情况 rss_router = respx.get("https://rsshub.app/twitter/user/ArknightsStaff") rss_router.mock(return_value=Response(200, text=get_file("rss-twitter-ArknightsStaff-0.xml"))) - target = "https://rsshub.app/twitter/user/ArknightsStaff" - res1 = await rss.fetch_new_post(target, [user_info_factory([], [])]) + target = Target("https://rsshub.app/twitter/user/ArknightsStaff") + res1 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])])) assert len(res1) == 0 rss_router.mock(return_value=Response(200, text=update_time_feed_1)) - res2 = await rss.fetch_new_post(target, [user_info_factory([], [])]) + res2 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])])) assert len(res2[0][1]) == 1 post1 = res2[0][1][0] assert post1.url == "https://twitter.com/ArknightsStaff/status/1659091539023282178" @@ -95,15 +97,17 @@ async def test_fetch_new_2( user_info_factory, update_time_feed_2, ): + from nonebot_bison.types import Target, SubUnit + ## 标题与正文不重复的情况 rss_router = respx.get("https://www.ruanyifeng.com/blog/atom.xml") rss_router.mock(return_value=Response(200, text=get_file("rss-ruanyifeng-0.xml"))) - target = "https://www.ruanyifeng.com/blog/atom.xml" - res1 = await rss.fetch_new_post(target, [user_info_factory([], [])]) + target = Target("https://www.ruanyifeng.com/blog/atom.xml") + res1 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])])) assert len(res1) == 0 rss_router.mock(return_value=Response(200, text=update_time_feed_2)) - res2 = await rss.fetch_new_post(target, [user_info_factory([], [])]) + res2 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])])) assert len(res2[0][1]) == 1 post1 = res2[0][1][0] assert post1.url == "http://www.ruanyifeng.com/blog/2023/05/weekly-issue-255.html" @@ -129,15 +133,17 @@ async def test_fetch_new_3( user_info_factory, update_time_feed_3, ): + from nonebot_bison.types import Target, SubUnit + ## 只有没有 rss_router = respx.get("https://github.com/R3nzTheCodeGOD/R3nzSkin/releases.atom") rss_router.mock(return_value=Response(200, text=get_file("rss-github-atom-0.xml"))) - target = "https://github.com/R3nzTheCodeGOD/R3nzSkin/releases.atom" - res1 = await rss.fetch_new_post(target, [user_info_factory([], [])]) + target = Target("https://github.com/R3nzTheCodeGOD/R3nzSkin/releases.atom") + res1 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])])) assert len(res1) == 0 rss_router.mock(return_value=Response(200, text=update_time_feed_3)) - res2 = await rss.fetch_new_post(target, [user_info_factory([], [])]) + res2 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])])) assert len(res2[0][1]) == 1 post1 = res2[0][1][0] assert post1.url == "https://github.com/R3nzTheCodeGOD/R3nzSkin/releases/tag/v3.0.9" @@ -150,15 +156,17 @@ async def test_fetch_new_4( rss, user_info_factory, ): + from nonebot_bison.types import Target, SubUnit + ## 没有日期信息的情况 rss_router = respx.get("https://rsshub.app/wallhaven/hot?limit=5") rss_router.mock(return_value=Response(200, text=get_file("rss-top5-old.xml"))) - target = "https://rsshub.app/wallhaven/hot?limit=5" - res1 = await rss.fetch_new_post(target, [user_info_factory([], [])]) + target = Target("https://rsshub.app/wallhaven/hot?limit=5") + res1 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])])) assert len(res1) == 0 rss_router.mock(return_value=Response(200, text=get_file("rss-top5-new.xml"))) - res2 = await rss.fetch_new_post(target, [user_info_factory([], [])]) + res2 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])])) assert len(res2[0][1]) == 1 post1 = res2[0][1][0] assert post1.url == "https://wallhaven.cc/w/85rjej" diff --git a/tests/platforms/test_weibo.py b/tests/platforms/test_weibo.py index 14148b8..ee6f7cb 100644 --- a/tests/platforms/test_weibo.py +++ b/tests/platforms/test_weibo.py @@ -41,25 +41,27 @@ async def test_get_name(weibo): @pytest.mark.asyncio @respx.mock async def test_fetch_new(weibo, dummy_user_subinfo): + from nonebot_bison.types import Target, SubUnit + ak_list_router = respx.get("https://m.weibo.cn/api/container/getIndex?containerid=1076036279793937") detail_router = respx.get("https://m.weibo.cn/detail/4649031014551911") ak_list_router.mock(return_value=Response(200, json=get_json("weibo_ak_list_0.json"))) detail_router.mock(return_value=Response(200, text=get_file("weibo_detail_4649031014551911"))) image_cdn_router.mock(Response(200, content=b"")) - target = "6279793937" - res = await weibo.fetch_new_post(target, [dummy_user_subinfo]) + target = Target("6279793937") + res = await weibo.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert ak_list_router.called assert len(res) == 0 assert not detail_router.called mock_data = get_json("weibo_ak_list_1.json") ak_list_router.mock(return_value=Response(200, json=mock_data)) - res2 = await weibo.fetch_new_post(target, [dummy_user_subinfo]) + res2 = await weibo.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert len(res2) == 0 mock_data["data"]["cards"][1]["mblog"]["created_at"] = datetime.now(timezone("Asia/Shanghai")).strftime( "%a %b %d %H:%M:%S %z %Y" ) ak_list_router.mock(return_value=Response(200, json=mock_data)) - res3 = await weibo.fetch_new_post(target, [dummy_user_subinfo]) + res3 = await weibo.fetch_new_post(SubUnit(target, [dummy_user_subinfo])) assert len(res3[0][1]) == 1 assert not detail_router.called post = res3[0][1][0] @@ -103,7 +105,9 @@ def test_tag(weibo, weibo_ak_list_1): @pytest.mark.asyncio @pytest.mark.compare async def test_rsshub_compare(weibo): - target = "6279793937" + from nonebot_bison.types import Target + + target = Target("6279793937") raw_posts = filter(weibo.filter_platform_custom, await weibo.get_sub_list(target)) posts = [] for raw_post in raw_posts: diff --git a/tests/scheduler/test_scheduler.py b/tests/scheduler/test_scheduler.py index 56d5419..df50601 100644 --- a/tests/scheduler/test_scheduler.py +++ b/tests/scheduler/test_scheduler.py @@ -1,7 +1,9 @@ import typing from datetime import time +from unittest.mock import AsyncMock from nonebug import App +from httpx import AsyncClient from pytest_mock import MockerFixture if typing.TYPE_CHECKING: @@ -50,6 +52,45 @@ async def test_scheduler_without_time(init_scheduler): assert static_res["bilibili-live-t2"] == 3 +async def test_scheduler_batch_api(init_scheduler, mocker: MockerFixture): + from nonebot_plugin_saa import TargetQQGroup + + from nonebot_bison.config import config + from nonebot_bison.types import UserSubInfo + from nonebot_bison.scheduler import scheduler_dict + from nonebot_bison.types import Target as T_Target + from nonebot_bison.scheduler.manager import init_scheduler + from nonebot_bison.platform.bilibili import BilibiliSchedConf + + await config.add_subscribe(TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili-live", [], []) + await config.add_subscribe(TargetQQGroup(group_id=123), T_Target("t2"), "target2", "bilibili-live", [], []) + + mocker.patch.object(BilibiliSchedConf, "get_client", return_value=AsyncClient()) + + await init_scheduler() + + batch_fetch_mock = AsyncMock() + + class FakePlatform: + def __init__(self) -> None: + self.do_batch_fetch_new_post = batch_fetch_mock + + fake_platform_obj = FakePlatform() + mocker.patch.dict( + "nonebot_bison.scheduler.scheduler.platform_manager", + {"bilibili-live": mocker.Mock(return_value=fake_platform_obj)}, + ) + + await scheduler_dict[BilibiliSchedConf].exec_fetch() + + batch_fetch_mock.assert_called_once_with( + [ + (T_Target("t1"), [UserSubInfo(user=TargetQQGroup(group_id=123), categories=[], tags=[])]), + (T_Target("t2"), [UserSubInfo(user=TargetQQGroup(group_id=123), categories=[], tags=[])]), + ] + ) + + async def test_scheduler_with_time(app: App, init_scheduler, mocker: MockerFixture): from nonebot_plugin_saa import TargetQQGroup