提供批量 api 接口 (#290)

* 🚧 add batch api

* 🚧 support batch in scheduler

*  use batch api in bilibili-live

 patch platform_manager directly

 use batch api in bilibili-live

 patch platform_manager directly

* ♻️ refactor var name

* 🐛 fix test

* 🐛 fix scheduler

* 🐛 fix test
This commit is contained in:
felinae98 2023-08-29 21:12:42 +08:00 committed by GitHub
parent 219e3ba5c6
commit e7dcfdee50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 519 additions and 166 deletions

View File

@ -201,6 +201,7 @@ class Bilibililive(StatusChange):
scheduler = BilibiliSchedConf
name = "Bilibili直播"
has_target = True
use_batch = True
@unique
class LiveStatus(Enum):
@ -281,12 +282,11 @@ class Bilibililive(StatusChange):
keyframe="",
)
async def get_status(self, target: Target) -> Info:
params = {"uids[]": target}
async def batch_get_status(self, targets: list[Target]) -> list[Info]:
# https://github.com/SocialSisterYi/bilibili-API-collect/blob/master/docs/live/info.md#批量查询直播间状态
res = await self.client.get(
"https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids",
params=params,
params={"uids[]": targets},
timeout=4.0,
)
res_dict = res.json()
@ -294,11 +294,14 @@ class Bilibililive(StatusChange):
if res_dict["code"] != 0:
raise self.FetchError()
data = res_dict.get("data")
if not data:
return self._gen_empty_info(uid=int(target))
room_data = data[target]
return self.Info.parse_obj(room_data)
data = res_dict.get("data", {})
infos = []
for target in targets:
if target in data.keys():
infos.append(self.Info.parse_obj(data[target]))
else:
infos.append(self._gen_empty_info(int(target)))
return infos
def compare_status(self, _: Target, old_status: Info, new_status: Info) -> list[RawPost]:
action = Bilibililive.LiveAction

View File

@ -2,11 +2,11 @@ import ssl
import json
import time
import typing
from typing import Any
from dataclasses import dataclass
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Collection
from typing import Any, TypeVar, ParamSpec
from collections.abc import Callable, Awaitable, Collection
import httpx
from httpx import AsyncClient
@ -16,7 +16,7 @@ from nonebot_plugin_saa import PlatformTarget
from ..post import Post
from ..plugin_config import plugin_config
from ..utils import ProcessContext, SchedulerConfig
from ..types import Tag, Target, RawPost, Category, UserSubInfo
from ..types import Tag, Target, RawPost, SubUnit, Category
class CategoryNotSupport(Exception):
@ -44,6 +44,26 @@ class RegistryMeta(type):
super().__init__(name, bases, namespace, **kwargs)
P = ParamSpec("P")
R = TypeVar("R")
async def catch_network_error(func: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs) -> R | None:
try:
return await func(*args, **kwargs)
except httpx.RequestError as err:
if plugin_config.bison_show_network_warning:
logger.warning(f"network connection error: {type(err)}, url: {err.request.url}")
return None
except ssl.SSLError as err:
if plugin_config.bison_show_network_warning:
logger.warning(f"ssl error: {err}")
return None
except json.JSONDecodeError as err:
logger.warning(f"json error, parsing: {err.doc}")
raise err
class PlatformMeta(RegistryMeta):
categories: dict[Category, str]
store: dict[Target, Any]
@ -75,6 +95,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
registry: list[type["Platform"]]
client: AsyncClient
reverse_category: dict[str, Category]
use_batch: bool = False
@classmethod
@abstractmethod
@ -82,25 +103,18 @@ class Platform(metaclass=PlatformABCMeta, base=True):
...
@abstractmethod
async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]:
async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
...
async def do_fetch_new_post(
self, target: Target, users: list[UserSubInfo]
) -> list[tuple[PlatformTarget, list[Post]]]:
try:
return await self.fetch_new_post(target, users)
except httpx.RequestError as err:
if plugin_config.bison_show_network_warning:
logger.warning(f"network connection error: {type(err)}, url: {err.request.url}")
return []
except ssl.SSLError as err:
if plugin_config.bison_show_network_warning:
logger.warning(f"ssl error: {err}")
return []
except json.JSONDecodeError as err:
logger.warning(f"json error, parsing: {err.doc}")
raise err
async def do_fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
return await catch_network_error(self.fetch_new_post, sub_unit) or []
@abstractmethod
async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
...
async def do_batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
return await catch_network_error(self.batch_fetch_new_post, sub_units) or []
@abstractmethod
async def parse(self, raw_post: RawPost) -> Post:
@ -190,10 +204,10 @@ class Platform(metaclass=PlatformABCMeta, base=True):
return res
async def dispatch_user_post(
self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]
self, new_posts: list[RawPost], sub_unit: SubUnit
) -> list[tuple[PlatformTarget, list[Post]]]:
res: list[tuple[PlatformTarget, list[Post]]] = []
for user, cats, required_tags in users:
for user, cats, required_tags in sub_unit.user_sub_infos:
user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
user_post: list[Post] = []
for raw_post in user_raw_post:
@ -235,6 +249,12 @@ class MessageProcess(Platform, abstract=True):
@abstractmethod
async def get_sub_list(self, target: Target) -> list[RawPost]:
"Get post list of the given target"
raise NotImplementedError()
@abstractmethod
async def batch_get_sub_list(self, targets: list[Target]) -> list[list[RawPost]]:
"Get post list of the given targets"
raise NotImplementedError()
@abstractmethod
def get_date(self, post: RawPost) -> int | None:
@ -298,9 +318,12 @@ class NewMessage(MessageProcess, abstract=True):
self.set_stored_data(target, store)
return res
async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]:
post_list = await self.get_sub_list(target)
new_posts = await self.filter_common_with_diff(target, post_list)
async def _handle_new_post(
self,
post_list: list[RawPost],
sub_unit: SubUnit,
) -> list[tuple[PlatformTarget, list[Post]]]:
new_posts = await self.filter_common_with_diff(sub_unit.sub_target, post_list)
if not new_posts:
return []
else:
@ -308,14 +331,27 @@ class NewMessage(MessageProcess, abstract=True):
logger.info(
"fetch new post from {} {}: {}".format(
self.platform_name,
target if self.has_target else "-",
sub_unit.sub_target if self.has_target else "-",
self.get_id(post),
)
)
res = await self.dispatch_user_post(target, new_posts, users)
res = await self.dispatch_user_post(new_posts, sub_unit)
self.parse_cache = {}
return res
async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
post_list = await self.get_sub_list(sub_unit.sub_target)
return await self._handle_new_post(post_list, sub_unit)
async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
if not self.has_target:
raise RuntimeError("Target without target should not use batch api") # pragma: no cover
posts_set = await self.batch_get_sub_list([x[0] for x in sub_units])
res = []
for sub_unit, posts in zip(sub_units, posts_set):
res.extend(await self._handle_new_post(posts, sub_unit))
return res
class StatusChange(Platform, abstract=True):
"Watch a status, and fire a post when status changes"
@ -327,6 +363,10 @@ class StatusChange(Platform, abstract=True):
async def get_status(self, target: Target) -> Any:
...
@abstractmethod
async def batch_get_status(self, targets: list[Target]) -> list[Any]:
...
@abstractmethod
def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]:
...
@ -335,34 +375,51 @@ class StatusChange(Platform, abstract=True):
async def parse(self, raw_post: RawPost) -> Post:
...
async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]:
try:
new_status = await self.get_status(target)
except self.FetchError as err:
logger.warning(f"fetching {self.name}-{target} error: {err}")
raise
async def _handle_status_change(
self, new_status: Any, sub_unit: SubUnit
) -> list[tuple[PlatformTarget, list[Post]]]:
res = []
if old_status := self.get_stored_data(target):
diff = self.compare_status(target, old_status, new_status)
if old_status := self.get_stored_data(sub_unit.sub_target):
diff = self.compare_status(sub_unit.sub_target, old_status, new_status)
if diff:
logger.info(
"status changes {} {}: {} -> {}".format(
self.platform_name,
target if self.has_target else "-",
sub_unit.sub_target if self.has_target else "-",
old_status,
new_status,
)
)
res = await self.dispatch_user_post(target, diff, users)
self.set_stored_data(target, new_status)
res = await self.dispatch_user_post(diff, sub_unit)
self.set_stored_data(sub_unit.sub_target, new_status)
return res
async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
try:
new_status = await self.get_status(sub_unit.sub_target)
except self.FetchError as err:
logger.warning(f"fetching {self.name}-{sub_unit.sub_target} error: {err}")
raise
return await self._handle_status_change(new_status, sub_unit)
async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
if not self.has_target:
raise RuntimeError("Target without target should not use batch api") # pragma: no cover
new_statuses = await self.batch_get_status([x[0] for x in sub_units])
res = []
for sub_unit, new_status in zip(sub_units, new_statuses):
res.extend(await self._handle_status_change(new_status, sub_unit))
return res
class SimplePost(MessageProcess, abstract=True):
class SimplePost(NewMessage, abstract=True):
"Fetch a list of messages, dispatch it to different users"
async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]:
new_posts = await self.get_sub_list(target)
async def _handle_new_post(
self,
new_posts: list[RawPost],
sub_unit: SubUnit,
) -> list[tuple[PlatformTarget, list[Post]]]:
if not new_posts:
return []
else:
@ -370,11 +427,11 @@ class SimplePost(MessageProcess, abstract=True):
logger.info(
"fetch new post from {} {}: {}".format(
self.platform_name,
target if self.has_target else "-",
sub_unit.sub_target if self.has_target else "-",
self.get_id(post),
)
)
res = await self.dispatch_user_post(target, new_posts, users)
res = await self.dispatch_user_post(new_posts, sub_unit)
self.parse_cache = {}
return res
@ -422,10 +479,10 @@ def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
async def get_target_name(cls, client: AsyncClient, target: Target):
return await platform_list[0].get_target_name(client, target)
async def fetch_new_post(self: "NoTargetGroup", target: Target, users: list[UserSubInfo]):
async def fetch_new_post(self: "NoTargetGroup", sub_unit: SubUnit):
res = defaultdict(list)
for platform in self.platform_obj_list:
platform_res = await platform.fetch_new_post(target=target, users=users)
platform_res = await platform.fetch_new_post(sub_unit)
for user, posts in platform_res:
res[user].extend(posts)
return [[key, val] for key, val in res.items()]

View File

@ -29,7 +29,9 @@ async def init_scheduler():
for scheduler_config, target_list in _schedule_class_dict.items():
schedulable_args = []
for target in target_list:
schedulable_args.append((target.platform_name, T_Target(target.target)))
schedulable_args.append(
(target.platform_name, T_Target(target.target), platform_manager[target.platform_name].use_batch)
)
platform_name_list = _schedule_class_platform_dict[scheduler_config]
scheduler_dict[scheduler_config] = Scheduler(scheduler_config, schedulable_args, platform_name_list)
config.register_add_target_hook(handle_insert_new_target)

View File

@ -1,12 +1,13 @@
from dataclasses import dataclass
from collections import defaultdict
from nonebot.log import logger
from nonebot_plugin_apscheduler import scheduler
from nonebot_plugin_saa.utils.exceptions import NoBotFound
from ..types import Target
from ..config import config
from ..send import send_msgs
from ..types import Target, SubUnit
from ..platform import platform_manager
from ..utils import ProcessContext, SchedulerConfig
@ -16,15 +17,18 @@ class Schedulable:
platform_name: str
target: Target
current_weight: int
use_batch: bool = False
class Scheduler:
schedulable_list: list[Schedulable]
schedulable_list: list[Schedulable] # for load weigth from db
batch_api_target_cache: dict[str, dict[Target, list[Target]]] # platform_name -> (target -> [target])
batch_platform_name_targets_cache: dict[str, list[Target]]
def __init__(
self,
scheduler_config: type[SchedulerConfig],
schedulables: list[tuple[str, Target]],
schedulables: list[tuple[str, Target, bool]], # [(platform_name, target, use_batch)]
platform_name_list: list[str],
):
self.name = scheduler_config.name
@ -33,9 +37,17 @@ class Scheduler:
raise RuntimeError(f"{self.name} not found")
self.scheduler_config = scheduler_config
self.scheduler_config_obj = self.scheduler_config()
self.schedulable_list = []
for platform_name, target in schedulables:
self.schedulable_list.append(Schedulable(platform_name=platform_name, target=target, current_weight=0))
self.batch_platform_name_targets_cache: dict[str, list[Target]] = defaultdict(list)
for platform_name, target, use_batch in schedulables:
if use_batch:
self.batch_platform_name_targets_cache[platform_name].append(target)
self.schedulable_list.append(
Schedulable(platform_name=platform_name, target=target, current_weight=0, use_batch=use_batch)
)
self._refresh_batch_api_target_cache()
self.platform_name_list = platform_name_list
self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值
logger.info(
@ -48,6 +60,12 @@ class Scheduler:
**self.scheduler_config.schedule_setting,
)
def _refresh_batch_api_target_cache(self):
self.batch_api_target_cache = defaultdict(dict)
for platform_name, targets in self.batch_platform_name_targets_cache.items():
for target in targets:
self.batch_api_target_cache[platform_name][target] = targets
async def get_next_schedulable(self) -> Schedulable | None:
if not self.schedulable_list:
return None
@ -69,14 +87,24 @@ class Scheduler:
if not (schedulable := await self.get_next_schedulable()):
return
logger.trace(f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}")
send_userinfo_list = await config.get_platform_target_subscribers(schedulable.platform_name, schedulable.target)
client = await self.scheduler_config_obj.get_client(schedulable.target)
context.register_to_client(client)
try:
platform_obj = platform_manager[schedulable.platform_name](context, client)
to_send = await platform_obj.do_fetch_new_post(schedulable.target, send_userinfo_list)
if schedulable.use_batch:
batch_targets = self.batch_api_target_cache[schedulable.platform_name][schedulable.target]
sub_units = []
for batch_target in batch_targets:
userinfo = await config.get_platform_target_subscribers(schedulable.platform_name, batch_target)
sub_units.append(SubUnit(batch_target, userinfo))
to_send = await platform_obj.do_batch_fetch_new_post(sub_units)
else:
send_userinfo_list = await config.get_platform_target_subscribers(
schedulable.platform_name, schedulable.target
)
to_send = await platform_obj.do_fetch_new_post(SubUnit(schedulable.target, send_userinfo_list))
except Exception as err:
records = context.gen_req_records()
for record in records:
@ -101,9 +129,18 @@ class Scheduler:
def insert_new_schedulable(self, platform_name: str, target: Target):
self.pre_weight_val += 1000
self.schedulable_list.append(Schedulable(platform_name, target, 1000))
if platform_manager[platform_name].use_batch:
self.batch_platform_name_targets_cache[platform_name].append(target)
self._refresh_batch_api_target_cache()
logger.info(f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})")
def delete_schedulable(self, platform_name, target: Target):
if platform_manager[platform_name].use_batch:
self.batch_platform_name_targets_cache[platform_name].remove(target)
self._refresh_batch_api_target_cache()
if not self.schedulable_list:
return
to_find_idx = None
@ -114,4 +151,3 @@ class Scheduler:
if to_find_idx is not None:
deleted_schdulable = self.schedulable_list.pop(to_find_idx)
self.pre_weight_val -= deleted_schdulable.current_weight
return

View File

@ -53,3 +53,8 @@ class ApiError(Exception):
def __init__(self, url: URL) -> None:
msg = f"api {url} error"
super().__init__(msg)
class SubUnit(NamedTuple):
sub_target: Target
user_sub_infos: list[UserSubInfo]

View File

@ -49,6 +49,8 @@ async def test_fetch_new(
monster_siren_list_0,
monster_siren_list_1,
):
from nonebot_bison.types import Target, SubUnit
ak_list_router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS")
detail_router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletin/5716")
version_router = respx.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version")
@ -63,14 +65,14 @@ async def test_fetch_new(
preannouncement_router.mock(return_value=Response(200, json=get_json("arknights-pre-0.json")))
monster_siren_router.mock(return_value=Response(200, json=monster_siren_list_0))
terra_list.mock(return_value=Response(200, json=get_json("terra-hist-0.json")))
target = ""
res = await arknights.fetch_new_post(target, [dummy_user_subinfo])
target = Target("")
res = await arknights.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert ak_list_router.called
assert len(res) == 0
assert not detail_router.called
mock_data = arknights_list_0
ak_list_router.mock(return_value=Response(200, json=mock_data))
res3 = await arknights.fetch_new_post(target, [dummy_user_subinfo])
res3 = await arknights.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert len(res3[0][1]) == 1
assert detail_router.called
post = res3[0][1][0]
@ -82,7 +84,7 @@ async def test_fetch_new(
# assert(post.pics == ['https://ak-fs.hypergryph.com/announce/images/20210623/e6f49aeb9547a2278678368a43b95b07.jpg'])
await post.generate_messages()
terra_list.mock(return_value=Response(200, json=get_json("terra-hist-1.json")))
res = await arknights.fetch_new_post(target, [dummy_user_subinfo])
res = await arknights.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert len(res) == 1
post = res[0][1][0]
assert post.target_type == "terra-historicus"
@ -101,6 +103,8 @@ async def test_send_with_render(
monster_siren_list_0,
monster_siren_list_1,
):
from nonebot_bison.types import Target, SubUnit
ak_list_router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS")
detail_router = respx.get("https://ak-webview.hypergryph.com/api/game/bulletin/8397")
version_router = respx.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version")
@ -115,14 +119,14 @@ async def test_send_with_render(
preannouncement_router.mock(return_value=Response(200, json=get_json("arknights-pre-0.json")))
monster_siren_router.mock(return_value=Response(200, json=monster_siren_list_0))
terra_list.mock(return_value=Response(200, json=get_json("terra-hist-0.json")))
target = ""
res = await arknights.fetch_new_post(target, [dummy_user_subinfo])
target = Target("")
res = await arknights.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert ak_list_router.called
assert len(res) == 0
assert not detail_router.called
mock_data = arknights_list_1
ak_list_router.mock(return_value=Response(200, json=mock_data))
res3 = await arknights.fetch_new_post(target, [dummy_user_subinfo])
res3 = await arknights.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert len(res3[0][1]) == 1
assert detail_router.called
post = res3[0][1][0]

View File

@ -106,14 +106,16 @@ async def test_dynamic_forward(bilibili, bing_dy_list):
@pytest.mark.asyncio
@respx.mock
async def test_fetch_new_without_dynamic(bilibili, dummy_user_subinfo, without_dynamic):
from nonebot_bison.types import Target, SubUnit
post_router = respx.get(
"https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history?host_uid=161775300&offset=0&need_top=0"
)
post_router.mock(return_value=Response(200, json=without_dynamic))
bilibili_main_page_router = respx.get("https://www.bilibili.com/")
bilibili_main_page_router.mock(return_value=Response(200))
target = "161775300"
res = await bilibili.fetch_new_post(target, [dummy_user_subinfo])
target = Target("161775300")
res = await bilibili.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert post_router.called
assert len(res) == 0
@ -121,21 +123,23 @@ async def test_fetch_new_without_dynamic(bilibili, dummy_user_subinfo, without_d
@pytest.mark.asyncio
@respx.mock
async def test_fetch_new(bilibili, dummy_user_subinfo):
from nonebot_bison.types import Target, SubUnit
post_router = respx.get(
"https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history?host_uid=161775300&offset=0&need_top=0"
)
post_router.mock(return_value=Response(200, json=get_json("bilibili_strange_post-0.json")))
bilibili_main_page_router = respx.get("https://www.bilibili.com/")
bilibili_main_page_router.mock(return_value=Response(200))
target = "161775300"
res = await bilibili.fetch_new_post(target, [dummy_user_subinfo])
target = Target("161775300")
res = await bilibili.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert post_router.called
assert len(res) == 0
mock_data = get_json("bilibili_strange_post.json")
mock_data["data"]["cards"][0]["desc"]["timestamp"] = int(datetime.now().timestamp())
post_router.mock(return_value=Response(200, json=mock_data))
res2 = await bilibili.fetch_new_post(target, [dummy_user_subinfo])
res2 = await bilibili.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert len(res2[0][1]) == 1
post = res2[0][1][0]
assert (

View File

@ -35,7 +35,7 @@ async def test_parse_target(bili_bangumi: "BilibiliBangumi"):
@pytest.mark.asyncio
@respx.mock
async def test_fetch_bilibili_bangumi_status(bili_bangumi: "BilibiliBangumi", dummy_user_subinfo):
from nonebot_bison.types import Target
from nonebot_bison.types import Target, SubUnit
bili_bangumi_router = respx.get("https://api.bilibili.com/pgc/review/user?media_id=28235413")
bili_bangumi_detail_router = respx.get("https://api.bilibili.com/pgc/view/web/season?season_id=39719")
@ -43,15 +43,15 @@ async def test_fetch_bilibili_bangumi_status(bili_bangumi: "BilibiliBangumi", du
bilibili_main_page_router = respx.get("https://www.bilibili.com/")
bilibili_main_page_router.mock(return_value=Response(200))
target = Target("28235413")
res = await bili_bangumi.fetch_new_post(target, [dummy_user_subinfo])
res = await bili_bangumi.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert len(res) == 0
res = await bili_bangumi.fetch_new_post(target, [dummy_user_subinfo])
res = await bili_bangumi.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert len(res) == 0
bili_bangumi_router.mock(return_value=Response(200, json=get_json("bilibili-gangumi-hanhua1.json")))
bili_bangumi_detail_router.mock(return_value=Response(200, json=get_json("bilibili-gangumi-hanhua1-detail.json")))
res2 = await bili_bangumi.fetch_new_post(target, [dummy_user_subinfo])
res2 = await bili_bangumi.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
post = res2[0][1][0]
assert post.target_type == "Bilibili剧集"

View File

@ -29,16 +29,18 @@ def dummy_only_open_user_subinfo(app: App):
@pytest.mark.asyncio
@respx.mock
async def test_fetch_bililive_no_room(bili_live, dummy_only_open_user_subinfo):
from nonebot_bison.types import Target, SubUnit
mock_bili_live_status = get_json("bili_live_status.json")
mock_bili_live_status["data"] = []
mock_bili_live_status["data"] = {}
bili_live_router = respx.get("https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids?uids[]=13164144")
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
bilibili_main_page_router = respx.get("https://www.bilibili.com/")
bilibili_main_page_router.mock(return_value=Response(200))
target = "13164144"
res = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo])
target = Target("13164144")
res = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))])
assert bili_live_router.call_count == 1
assert len(res) == 0
@ -46,23 +48,25 @@ async def test_fetch_bililive_no_room(bili_live, dummy_only_open_user_subinfo):
@pytest.mark.asyncio
@respx.mock
async def test_fetch_first_live(bili_live, dummy_only_open_user_subinfo):
from nonebot_bison.types import Target, SubUnit
mock_bili_live_status = get_json("bili_live_status.json")
empty_bili_live_status = deepcopy(mock_bili_live_status)
empty_bili_live_status["data"] = []
empty_bili_live_status["data"] = {}
bili_live_router = respx.get("https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids?uids[]=13164144")
bili_live_router.mock(return_value=Response(200, json=empty_bili_live_status))
bilibili_main_page_router = respx.get("https://www.bilibili.com/")
bilibili_main_page_router.mock(return_value=Response(200))
target = "13164144"
res = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo])
target = Target("13164144")
res = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))])
assert bili_live_router.call_count == 1
assert len(res) == 0
mock_bili_live_status["data"][target]["live_status"] = 1
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res2 = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo])
res2 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))])
assert bili_live_router.call_count == 2
assert len(res2) == 1
post = res2[0][1][0]
@ -77,6 +81,8 @@ async def test_fetch_first_live(bili_live, dummy_only_open_user_subinfo):
@pytest.mark.asyncio
@respx.mock
async def test_fetch_bililive_only_live_open(bili_live, dummy_only_open_user_subinfo):
from nonebot_bison.types import Target, SubUnit
mock_bili_live_status = get_json("bili_live_status.json")
bili_live_router = respx.get("https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids?uids[]=13164144")
@ -85,14 +91,14 @@ async def test_fetch_bililive_only_live_open(bili_live, dummy_only_open_user_sub
bilibili_main_page_router = respx.get("https://www.bilibili.com/")
bilibili_main_page_router.mock(return_value=Response(200))
target = "13164144"
res = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo])
target = Target("13164144")
res = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))])
assert bili_live_router.call_count == 1
assert len(res[0][1]) == 0
# 直播状态更新-上播
mock_bili_live_status["data"][target]["live_status"] = 1
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res2 = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo])
res2 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))])
post = res2[0][1][0]
assert post.target_type == "Bilibili直播"
assert post.text == "[开播] 【Zc】从0挑战到15肉鸽目前10难度"
@ -103,13 +109,13 @@ async def test_fetch_bililive_only_live_open(bili_live, dummy_only_open_user_sub
# 标题变更
mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽目前11难度"
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res3 = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo])
res3 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))])
assert bili_live_router.call_count == 3
assert len(res3[0][1]) == 0
# 直播状态更新-下播
mock_bili_live_status["data"][target]["live_status"] = 0
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res4 = await bili_live.fetch_new_post(target, [dummy_only_open_user_subinfo])
res4 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_open_user_subinfo]))])
assert bili_live_router.call_count == 4
assert len(res4[0][1]) == 0
@ -127,8 +133,10 @@ def dummy_only_title_user_subinfo(app: App):
@pytest.mark.asyncio()
@respx.mock
async def test_fetch_bililive_only_title_change(bili_live, dummy_only_title_user_subinfo):
from nonebot_bison.types import Target, SubUnit
mock_bili_live_status = get_json("bili_live_status.json")
target = "13164144"
target = Target("13164144")
bili_live_router = respx.get("https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids?uids[]=13164144")
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
@ -136,25 +144,25 @@ async def test_fetch_bililive_only_title_change(bili_live, dummy_only_title_user
bilibili_main_page_router = respx.get("https://www.bilibili.com/")
bilibili_main_page_router.mock(return_value=Response(200))
res = await bili_live.fetch_new_post(target, [dummy_only_title_user_subinfo])
res = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_title_user_subinfo]))])
assert bili_live_router.call_count == 1
assert len(res) == 0
# 未开播前标题变更
mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽目前11难度"
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res0 = await bili_live.fetch_new_post(target, [dummy_only_title_user_subinfo])
res0 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_title_user_subinfo]))])
assert bili_live_router.call_count == 2
assert len(res0) == 0
# 直播状态更新-上播
mock_bili_live_status["data"][target]["live_status"] = 1
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res2 = await bili_live.fetch_new_post(target, [dummy_only_title_user_subinfo])
res2 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_title_user_subinfo]))])
assert bili_live_router.call_count == 3
assert len(res2[0][1]) == 0
# 标题变更
mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽目前12难度"
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res3 = await bili_live.fetch_new_post(target, [dummy_only_title_user_subinfo])
res3 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_title_user_subinfo]))])
post = res3[0][1][0]
assert post.target_type == "Bilibili直播"
assert post.text == "[标题更新] 【Zc】从0挑战到15肉鸽目前12难度"
@ -165,7 +173,7 @@ async def test_fetch_bililive_only_title_change(bili_live, dummy_only_title_user
# 直播状态更新-下播
mock_bili_live_status["data"][target]["live_status"] = 0
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res4 = await bili_live.fetch_new_post(target, [dummy_only_title_user_subinfo])
res4 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_title_user_subinfo]))])
assert bili_live_router.call_count == 5
assert len(res4[0][1]) == 0
@ -183,8 +191,10 @@ def dummy_only_close_user_subinfo(app: App):
@pytest.mark.asyncio
@respx.mock
async def test_fetch_bililive_only_close(bili_live, dummy_only_close_user_subinfo):
from nonebot_bison.types import Target, SubUnit
mock_bili_live_status = get_json("bili_live_status.json")
target = "13164144"
target = Target("13164144")
bili_live_router = respx.get("https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids?uids[]=13164144")
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
@ -192,31 +202,31 @@ async def test_fetch_bililive_only_close(bili_live, dummy_only_close_user_subinf
bilibili_main_page_router = respx.get("https://www.bilibili.com/")
bilibili_main_page_router.mock(return_value=Response(200))
res = await bili_live.fetch_new_post(target, [dummy_only_close_user_subinfo])
res = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_close_user_subinfo]))])
assert bili_live_router.call_count == 1
assert len(res) == 0
# 未开播前标题变更
mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽目前11难度"
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res0 = await bili_live.fetch_new_post(target, [dummy_only_close_user_subinfo])
res0 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_close_user_subinfo]))])
assert bili_live_router.call_count == 2
assert len(res0) == 0
# 直播状态更新-上播
mock_bili_live_status["data"][target]["live_status"] = 1
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res2 = await bili_live.fetch_new_post(target, [dummy_only_close_user_subinfo])
res2 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_close_user_subinfo]))])
assert bili_live_router.call_count == 3
assert len(res2[0][1]) == 0
# 标题变更
mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽目前12难度"
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res3 = await bili_live.fetch_new_post(target, [dummy_only_close_user_subinfo])
res3 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_close_user_subinfo]))])
assert bili_live_router.call_count == 4
assert len(res3[0][1]) == 0
# 直播状态更新-下播
mock_bili_live_status["data"][target]["live_status"] = 0
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res4 = await bili_live.fetch_new_post(target, [dummy_only_close_user_subinfo])
res4 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_only_close_user_subinfo]))])
assert bili_live_router.call_count == 5
post = res4[0][1][0]
assert post.target_type == "Bilibili直播"
@ -240,8 +250,10 @@ def dummy_bililive_user_subinfo(app: App):
@pytest.mark.asyncio
@respx.mock
async def test_fetch_bililive_combo(bili_live, dummy_bililive_user_subinfo):
from nonebot_bison.types import Target, SubUnit
mock_bili_live_status = get_json("bili_live_status.json")
target = "13164144"
target = Target("13164144")
bili_live_router = respx.get("https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids?uids[]=13164144")
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
@ -249,19 +261,19 @@ async def test_fetch_bililive_combo(bili_live, dummy_bililive_user_subinfo):
bilibili_main_page_router = respx.get("https://www.bilibili.com/")
bilibili_main_page_router.mock(return_value=Response(200))
res = await bili_live.fetch_new_post(target, [dummy_bililive_user_subinfo])
res = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_bililive_user_subinfo]))])
assert bili_live_router.call_count == 1
assert len(res) == 0
# 未开播前标题变更
mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽目前11难度"
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res0 = await bili_live.fetch_new_post(target, [dummy_bililive_user_subinfo])
res0 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_bililive_user_subinfo]))])
assert bili_live_router.call_count == 2
assert len(res0) == 0
# 直播状态更新-上播
mock_bili_live_status["data"][target]["live_status"] = 1
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res2 = await bili_live.fetch_new_post(target, [dummy_bililive_user_subinfo])
res2 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_bililive_user_subinfo]))])
post2 = res2[0][1][0]
assert post2.target_type == "Bilibili直播"
assert post2.text == "[开播] 【Zc】从0挑战到15肉鸽目前11难度"
@ -272,7 +284,7 @@ async def test_fetch_bililive_combo(bili_live, dummy_bililive_user_subinfo):
# 标题变更
mock_bili_live_status["data"][target]["title"] = "【Zc】从0挑战到15肉鸽目前12难度"
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res3 = await bili_live.fetch_new_post(target, [dummy_bililive_user_subinfo])
res3 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_bililive_user_subinfo]))])
post3 = res3[0][1][0]
assert post3.target_type == "Bilibili直播"
assert post3.text == "[标题更新] 【Zc】从0挑战到15肉鸽目前12难度"
@ -283,7 +295,7 @@ async def test_fetch_bililive_combo(bili_live, dummy_bililive_user_subinfo):
# 直播状态更新-下播
mock_bili_live_status["data"][target]["live_status"] = 0
bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status))
res4 = await bili_live.fetch_new_post(target, [dummy_bililive_user_subinfo])
res4 = await bili_live.batch_fetch_new_post([(SubUnit(target, [dummy_bililive_user_subinfo]))])
post4 = res4[0][1][0]
assert post4.target_type == "Bilibili直播"
assert post4.text == "[下播] 【Zc】从0挑战到15肉鸽目前12难度"

View File

@ -27,16 +27,18 @@ def ff14_newdata_json_1():
@pytest.mark.asyncio
@respx.mock
async def test_fetch_new(ff14, dummy_user_subinfo, ff14_newdata_json_0, ff14_newdata_json_1):
from nonebot_bison.types import Target, SubUnit
newdata = respx.get(
"https://cqnews.web.sdo.com/api/news/newsList?gameCode=ff&CategoryCode=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5"
)
newdata.mock(return_value=Response(200, json=ff14_newdata_json_0))
target = ""
res = await ff14.fetch_new_post(target, [dummy_user_subinfo])
target = Target("")
res = await ff14.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert newdata.called
assert len(res) == 0
newdata.mock(return_value=Response(200, json=ff14_newdata_json_1))
res = await ff14.fetch_new_post(target, [dummy_user_subinfo])
res = await ff14.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert newdata.called
post = res[0][1][0]
assert post.target_type == "ff14"

View File

@ -41,14 +41,16 @@ def ncm_artist_1(ncm_artist_raw: dict):
@pytest.mark.asyncio
@respx.mock
async def test_fetch_new(ncm_artist, ncm_artist_0, ncm_artist_1, dummy_user_subinfo):
from nonebot_bison.types import Target, SubUnit
ncm_router = respx.get("https://music.163.com/api/artist/albums/32540734")
ncm_router.mock(return_value=Response(200, json=ncm_artist_0))
target = "32540734"
res = await ncm_artist.fetch_new_post(target, [dummy_user_subinfo])
target = Target("32540734")
res = await ncm_artist.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert ncm_router.called
assert len(res) == 0
ncm_router.mock(return_value=Response(200, json=ncm_artist_1))
res2 = await ncm_artist.fetch_new_post(target, [dummy_user_subinfo])
res2 = await ncm_artist.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
post = res2[0][1][0]
assert post.target_type == "ncm-artist"
assert post.text == "新专辑发布Y1K"

View File

@ -41,14 +41,16 @@ def ncm_radio_1(ncm_radio_raw: dict):
@pytest.mark.asyncio
@respx.mock
async def test_fetch_new(ncm_radio, ncm_radio_0, ncm_radio_1, dummy_user_subinfo):
from nonebot_bison.types import Target, SubUnit
ncm_router = respx.post("http://music.163.com/api/dj/program/byradio")
ncm_router.mock(return_value=Response(200, json=ncm_radio_0))
target = "793745436"
res = await ncm_radio.fetch_new_post(target, [dummy_user_subinfo])
target = Target("793745436")
res = await ncm_radio.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert ncm_router.called
assert len(res) == 0
ncm_router.mock(return_value=Response(200, json=ncm_radio_1))
res2 = await ncm_radio.fetch_new_post(target, [dummy_user_subinfo])
res2 = await ncm_radio.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
post = res2[0][1][0]
assert post.target_type == "ncm-radio"
assert post.text == "网易云电台更新:「松烟行动」灰齐山麓"

View File

@ -325,16 +325,14 @@ def mock_status_change(app: App):
@pytest.mark.asyncio
async def test_new_message_target_without_cats_tags(mock_platform_without_cats_tags, user_info_factory):
from nonebot_bison.utils import ProcessContext
from nonebot_bison.types import Target, SubUnit
res1 = await mock_platform_without_cats_tags(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
)
assert len(res1) == 0
res2 = await mock_platform_without_cats_tags(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy",
[
user_info_factory([], []),
],
SubUnit(Target("dummy"), [user_info_factory([], [])]),
)
assert len(res2) == 1
posts_1 = res2[0][1]
@ -348,16 +346,21 @@ async def test_new_message_target_without_cats_tags(mock_platform_without_cats_t
@pytest.mark.asyncio
async def test_new_message_target(mock_platform, user_info_factory):
from nonebot_bison.utils import ProcessContext
from nonebot_bison.types import Target, SubUnit
res1 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post("dummy", [user_info_factory([1, 2], [])])
res1 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post(
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
)
assert len(res1) == 0
res2 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy",
[
user_info_factory([1, 2], []),
user_info_factory([1], []),
user_info_factory([1, 2], ["tag1"]),
],
SubUnit(
Target("dummy"),
[
user_info_factory([1, 2], []),
user_info_factory([1], []),
user_info_factory([1, 2], ["tag1"]),
],
),
)
assert len(res2) == 3
posts_1 = res2[0][1]
@ -378,18 +381,21 @@ async def test_new_message_target(mock_platform, user_info_factory):
@pytest.mark.asyncio
async def test_new_message_no_target(mock_platform_no_target, user_info_factory):
from nonebot_bison.utils import ProcessContext
from nonebot_bison.types import Target, SubUnit
res1 = await mock_platform_no_target(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
)
assert len(res1) == 0
res2 = await mock_platform_no_target(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy",
[
user_info_factory([1, 2], []),
user_info_factory([1], []),
user_info_factory([1, 2], ["tag1"]),
],
SubUnit(
Target("dummy"),
[
user_info_factory([1, 2], []),
user_info_factory([1], []),
user_info_factory([1, 2], ["tag1"]),
],
),
)
assert len(res2) == 3
posts_1 = res2[0][1]
@ -406,7 +412,7 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory)
assert "p2" in id_set_2
assert "p2" in id_set_3
res3 = await mock_platform_no_target(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
)
assert len(res3) == 0
@ -414,31 +420,34 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory)
@pytest.mark.asyncio
async def test_status_change(mock_status_change, user_info_factory):
from nonebot_bison.utils import ProcessContext
from nonebot_bison.types import Target, SubUnit
res1 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
)
assert len(res1) == 0
res2 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
)
assert len(res2) == 1
posts = res2[0][1]
assert len(posts) == 1
assert posts[0].text == "on"
res3 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy",
[
user_info_factory([1, 2], []),
user_info_factory([1], []),
],
SubUnit(
Target("dummy"),
[
user_info_factory([1, 2], []),
user_info_factory([1], []),
],
),
)
assert len(res3) == 2
assert len(res3[0][1]) == 1
assert res3[0][1][0].text == "off"
assert len(res3[1][1]) == 0
res4 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
SubUnit(Target("dummy"), [user_info_factory([1, 2], [])])
)
assert len(res4) == 0
@ -450,7 +459,7 @@ async def test_group(
mock_platform_no_target_2,
user_info_factory,
):
from nonebot_bison.types import Target
from nonebot_bison.types import Target, SubUnit
from nonebot_bison.utils import ProcessContext, http_client
from nonebot_bison.platform.platform import make_no_target_group
@ -458,14 +467,176 @@ async def test_group(
group_platform_class = make_no_target_group([mock_platform_no_target, mock_platform_no_target_2])
group_platform = group_platform_class(ProcessContext(), http_client())
res1 = await group_platform.fetch_new_post(dummy, [user_info_factory([1, 4], [])])
res1 = await group_platform.fetch_new_post(SubUnit(dummy, [user_info_factory([1, 4], [])]))
assert len(res1) == 0
res2 = await group_platform.fetch_new_post(dummy, [user_info_factory([1, 4], [])])
res2 = await group_platform.fetch_new_post(SubUnit(dummy, [user_info_factory([1, 4], [])]))
assert len(res2) == 1
posts = res2[0][1]
assert len(posts) == 2
id_set_2 = {x.text for x in posts}
assert "p2" in id_set_2
assert "p6" in id_set_2
res3 = await group_platform.fetch_new_post(dummy, [user_info_factory([1, 4], [])])
res3 = await group_platform.fetch_new_post(SubUnit(dummy, [user_info_factory([1, 4], [])]))
assert len(res3) == 0
async def test_batch_fetch_new_message(app: App):
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.post import Post
from nonebot_bison.platform.platform import NewMessage
from nonebot_bison.utils.context import ProcessContext
from nonebot_bison.types import Target, RawPost, SubUnit, UserSubInfo
class BatchNewMessage(NewMessage):
platform_name = "mock_platform"
name = "Mock Platform"
enabled = True
is_common = True
schedule_interval = 10
enable_tag = False
categories = {}
has_target = True
sub_index = 0
@classmethod
async def get_target_name(cls, client, _: "Target"):
return "MockPlatform"
def get_id(self, post: "RawPost") -> Any:
return post["id"]
def get_date(self, raw_post: "RawPost") -> float:
return raw_post["date"]
async def parse(self, raw_post: "RawPost") -> "Post":
return Post(
"mock_platform",
raw_post["text"],
"http://t.tt/" + str(self.get_id(raw_post)),
target_name="Mock",
)
@classmethod
async def batch_get_sub_list(cls, targets: list[Target]) -> list[list[RawPost]]:
assert len(targets) > 1
if cls.sub_index == 0:
cls.sub_index += 1
res = [
[raw_post_list_2[0]],
[raw_post_list_2[1]],
]
else:
res = [
[raw_post_list_2[0], raw_post_list_2[2]],
[raw_post_list_2[1], raw_post_list_2[3]],
]
res += [[]] * (len(targets) - 2)
return res
user1 = UserSubInfo(TargetQQGroup(group_id=123), [1, 2, 3], [])
user2 = UserSubInfo(TargetQQGroup(group_id=234), [1, 2, 3], [])
platform_obj = BatchNewMessage(ProcessContext(), None) # type:ignore
res1 = await platform_obj.batch_fetch_new_post(
[
SubUnit(Target("target1"), [user1]),
SubUnit(Target("target2"), [user1, user2]),
SubUnit(Target("target3"), [user2]),
]
)
assert len(res1) == 0
res2 = await platform_obj.batch_fetch_new_post(
[
SubUnit(Target("target1"), [user1]),
SubUnit(Target("target2"), [user1, user2]),
SubUnit(Target("target3"), [user2]),
]
)
assert len(res2) == 3
send_set = set()
for platform_target, posts in res2:
for post in posts:
send_set.add((platform_target, post.text))
assert (TargetQQGroup(group_id=123), "p3") in send_set
assert (TargetQQGroup(group_id=123), "p4") in send_set
assert (TargetQQGroup(group_id=234), "p4") in send_set
async def test_batch_fetch_compare_status(app: App):
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.post import Post
from nonebot_bison.utils.context import ProcessContext
from nonebot_bison.platform.platform import StatusChange
from nonebot_bison.types import Target, RawPost, SubUnit, Category, UserSubInfo
class BatchStatusChange(StatusChange):
platform_name = "mock_platform"
name = "Mock Platform"
enabled = True
is_common = True
enable_tag = False
schedule_type = "interval"
schedule_kw = {"seconds": 10}
has_target = False
categories = {
Category(1): "转发",
Category(2): "视频",
}
sub_index = 0
@classmethod
async def batch_get_status(cls, targets: "list[Target]"):
assert len(targets) > 0
res = [{"s": cls.sub_index == 1} for _ in targets]
res[0]["s"] = not res[0]["s"]
if cls.sub_index == 0:
cls.sub_index += 1
return res
def compare_status(self, target, old_status, new_status) -> list["RawPost"]:
if old_status["s"] is False and new_status["s"] is True:
return [{"text": "on", "cat": 1}]
elif old_status["s"] is True and new_status["s"] is False:
return [{"text": "off", "cat": 2}]
return []
async def parse(self, raw_post) -> "Post":
return Post("mock_status", raw_post["text"], "")
def get_category(self, raw_post):
return raw_post["cat"]
batch_status_change = BatchStatusChange(ProcessContext(), None) # type: ignore
user1 = UserSubInfo(TargetQQGroup(group_id=123), [1, 2, 3], [])
user2 = UserSubInfo(TargetQQGroup(group_id=234), [1, 2, 3], [])
res1 = await batch_status_change.batch_fetch_new_post(
[
SubUnit(Target("target1"), [user1]),
SubUnit(Target("target2"), [user1, user2]),
]
)
assert len(res1) == 0
res2 = await batch_status_change.batch_fetch_new_post(
[
SubUnit(Target("target1"), [user1]),
SubUnit(Target("target2"), [user1, user2]),
]
)
send_set = set()
for platform_target, posts in res2:
for post in posts:
send_set.add((platform_target, post.text))
assert len(send_set) == 3
assert (TargetQQGroup(group_id=123), "off") in send_set
assert (TargetQQGroup(group_id=123), "on") in send_set
assert (TargetQQGroup(group_id=234), "on") in send_set

View File

@ -68,15 +68,17 @@ async def test_fetch_new_1(
user_info_factory,
update_time_feed_1,
):
from nonebot_bison.types import Target, SubUnit
## 标题重复的情况
rss_router = respx.get("https://rsshub.app/twitter/user/ArknightsStaff")
rss_router.mock(return_value=Response(200, text=get_file("rss-twitter-ArknightsStaff-0.xml")))
target = "https://rsshub.app/twitter/user/ArknightsStaff"
res1 = await rss.fetch_new_post(target, [user_info_factory([], [])])
target = Target("https://rsshub.app/twitter/user/ArknightsStaff")
res1 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])]))
assert len(res1) == 0
rss_router.mock(return_value=Response(200, text=update_time_feed_1))
res2 = await rss.fetch_new_post(target, [user_info_factory([], [])])
res2 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])]))
assert len(res2[0][1]) == 1
post1 = res2[0][1][0]
assert post1.url == "https://twitter.com/ArknightsStaff/status/1659091539023282178"
@ -95,15 +97,17 @@ async def test_fetch_new_2(
user_info_factory,
update_time_feed_2,
):
from nonebot_bison.types import Target, SubUnit
## 标题与正文不重复的情况
rss_router = respx.get("https://www.ruanyifeng.com/blog/atom.xml")
rss_router.mock(return_value=Response(200, text=get_file("rss-ruanyifeng-0.xml")))
target = "https://www.ruanyifeng.com/blog/atom.xml"
res1 = await rss.fetch_new_post(target, [user_info_factory([], [])])
target = Target("https://www.ruanyifeng.com/blog/atom.xml")
res1 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])]))
assert len(res1) == 0
rss_router.mock(return_value=Response(200, text=update_time_feed_2))
res2 = await rss.fetch_new_post(target, [user_info_factory([], [])])
res2 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])]))
assert len(res2[0][1]) == 1
post1 = res2[0][1][0]
assert post1.url == "http://www.ruanyifeng.com/blog/2023/05/weekly-issue-255.html"
@ -129,15 +133,17 @@ async def test_fetch_new_3(
user_info_factory,
update_time_feed_3,
):
from nonebot_bison.types import Target, SubUnit
## 只有<updated>没有<published>
rss_router = respx.get("https://github.com/R3nzTheCodeGOD/R3nzSkin/releases.atom")
rss_router.mock(return_value=Response(200, text=get_file("rss-github-atom-0.xml")))
target = "https://github.com/R3nzTheCodeGOD/R3nzSkin/releases.atom"
res1 = await rss.fetch_new_post(target, [user_info_factory([], [])])
target = Target("https://github.com/R3nzTheCodeGOD/R3nzSkin/releases.atom")
res1 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])]))
assert len(res1) == 0
rss_router.mock(return_value=Response(200, text=update_time_feed_3))
res2 = await rss.fetch_new_post(target, [user_info_factory([], [])])
res2 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])]))
assert len(res2[0][1]) == 1
post1 = res2[0][1][0]
assert post1.url == "https://github.com/R3nzTheCodeGOD/R3nzSkin/releases/tag/v3.0.9"
@ -150,15 +156,17 @@ async def test_fetch_new_4(
rss,
user_info_factory,
):
from nonebot_bison.types import Target, SubUnit
## 没有日期信息的情况
rss_router = respx.get("https://rsshub.app/wallhaven/hot?limit=5")
rss_router.mock(return_value=Response(200, text=get_file("rss-top5-old.xml")))
target = "https://rsshub.app/wallhaven/hot?limit=5"
res1 = await rss.fetch_new_post(target, [user_info_factory([], [])])
target = Target("https://rsshub.app/wallhaven/hot?limit=5")
res1 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])]))
assert len(res1) == 0
rss_router.mock(return_value=Response(200, text=get_file("rss-top5-new.xml")))
res2 = await rss.fetch_new_post(target, [user_info_factory([], [])])
res2 = await rss.fetch_new_post(SubUnit(target, [user_info_factory([], [])]))
assert len(res2[0][1]) == 1
post1 = res2[0][1][0]
assert post1.url == "https://wallhaven.cc/w/85rjej"

View File

@ -41,25 +41,27 @@ async def test_get_name(weibo):
@pytest.mark.asyncio
@respx.mock
async def test_fetch_new(weibo, dummy_user_subinfo):
from nonebot_bison.types import Target, SubUnit
ak_list_router = respx.get("https://m.weibo.cn/api/container/getIndex?containerid=1076036279793937")
detail_router = respx.get("https://m.weibo.cn/detail/4649031014551911")
ak_list_router.mock(return_value=Response(200, json=get_json("weibo_ak_list_0.json")))
detail_router.mock(return_value=Response(200, text=get_file("weibo_detail_4649031014551911")))
image_cdn_router.mock(Response(200, content=b""))
target = "6279793937"
res = await weibo.fetch_new_post(target, [dummy_user_subinfo])
target = Target("6279793937")
res = await weibo.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert ak_list_router.called
assert len(res) == 0
assert not detail_router.called
mock_data = get_json("weibo_ak_list_1.json")
ak_list_router.mock(return_value=Response(200, json=mock_data))
res2 = await weibo.fetch_new_post(target, [dummy_user_subinfo])
res2 = await weibo.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert len(res2) == 0
mock_data["data"]["cards"][1]["mblog"]["created_at"] = datetime.now(timezone("Asia/Shanghai")).strftime(
"%a %b %d %H:%M:%S %z %Y"
)
ak_list_router.mock(return_value=Response(200, json=mock_data))
res3 = await weibo.fetch_new_post(target, [dummy_user_subinfo])
res3 = await weibo.fetch_new_post(SubUnit(target, [dummy_user_subinfo]))
assert len(res3[0][1]) == 1
assert not detail_router.called
post = res3[0][1][0]
@ -103,7 +105,9 @@ def test_tag(weibo, weibo_ak_list_1):
@pytest.mark.asyncio
@pytest.mark.compare
async def test_rsshub_compare(weibo):
target = "6279793937"
from nonebot_bison.types import Target
target = Target("6279793937")
raw_posts = filter(weibo.filter_platform_custom, await weibo.get_sub_list(target))
posts = []
for raw_post in raw_posts:

View File

@ -1,7 +1,9 @@
import typing
from datetime import time
from unittest.mock import AsyncMock
from nonebug import App
from httpx import AsyncClient
from pytest_mock import MockerFixture
if typing.TYPE_CHECKING:
@ -50,6 +52,45 @@ async def test_scheduler_without_time(init_scheduler):
assert static_res["bilibili-live-t2"] == 3
async def test_scheduler_batch_api(init_scheduler, mocker: MockerFixture):
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config
from nonebot_bison.types import UserSubInfo
from nonebot_bison.scheduler import scheduler_dict
from nonebot_bison.types import Target as T_Target
from nonebot_bison.scheduler.manager import init_scheduler
from nonebot_bison.platform.bilibili import BilibiliSchedConf
await config.add_subscribe(TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili-live", [], [])
await config.add_subscribe(TargetQQGroup(group_id=123), T_Target("t2"), "target2", "bilibili-live", [], [])
mocker.patch.object(BilibiliSchedConf, "get_client", return_value=AsyncClient())
await init_scheduler()
batch_fetch_mock = AsyncMock()
class FakePlatform:
def __init__(self) -> None:
self.do_batch_fetch_new_post = batch_fetch_mock
fake_platform_obj = FakePlatform()
mocker.patch.dict(
"nonebot_bison.scheduler.scheduler.platform_manager",
{"bilibili-live": mocker.Mock(return_value=fake_platform_obj)},
)
await scheduler_dict[BilibiliSchedConf].exec_fetch()
batch_fetch_mock.assert_called_once_with(
[
(T_Target("t1"), [UserSubInfo(user=TargetQQGroup(group_id=123), categories=[], tags=[])]),
(T_Target("t2"), [UserSubInfo(user=TargetQQGroup(group_id=123), categories=[], tags=[])]),
]
)
async def test_scheduler_with_time(app: App, init_scheduler, mocker: MockerFixture):
from nonebot_plugin_saa import TargetQQGroup