From bd679914eba25c4dc8dbc27eedd9201d23d8ceb0 Mon Sep 17 00:00:00 2001 From: felinae98 <731499577@qq.com> Date: Thu, 24 Nov 2022 13:12:56 +0800 Subject: [PATCH] :sparkles: add context to log http error --- .../nonebot_bison/platform/bilibili.py | 16 ++-- src/plugins/nonebot_bison/platform/ff14.py | 11 ++- src/plugins/nonebot_bison/platform/ncm.py | 80 +++++++++---------- .../nonebot_bison/platform/platform.py | 23 +++--- src/plugins/nonebot_bison/platform/rss.py | 22 +++-- src/plugins/nonebot_bison/platform/weibo.py | 2 +- .../nonebot_bison/scheduler/scheduler.py | 25 ++++-- src/plugins/nonebot_bison/types.py | 7 ++ src/plugins/nonebot_bison/utils/__init__.py | 2 + src/plugins/nonebot_bison/utils/context.py | 40 ++++++++++ tests/platforms/test_arknights.py | 3 +- tests/platforms/test_bilibili.py | 3 +- tests/platforms/test_bilibili_bangumi.py | 3 +- tests/platforms/test_bilibili_live.py | 3 +- tests/platforms/test_ff14.py | 3 +- tests/platforms/test_mcbbsnews.py | 3 +- tests/platforms/test_ncm_artist.py | 3 +- tests/platforms/test_ncm_radio.py | 3 +- tests/platforms/test_platform.py | 69 ++++++++-------- tests/platforms/test_platform_tag_filter.py | 6 +- tests/platforms/test_weibo.py | 3 +- tests/test_context.py | 20 +++++ 22 files changed, 218 insertions(+), 132 deletions(-) create mode 100644 src/plugins/nonebot_bison/utils/context.py create mode 100644 tests/test_context.py diff --git a/src/plugins/nonebot_bison/platform/bilibili.py b/src/plugins/nonebot_bison/platform/bilibili.py index a1144fc..7bdb1d7 100644 --- a/src/plugins/nonebot_bison/platform/bilibili.py +++ b/src/plugins/nonebot_bison/platform/bilibili.py @@ -1,21 +1,17 @@ -import functools import json import re from copy import deepcopy from dataclasses import dataclass, field from datetime import datetime, timedelta -from typing import Any, Callable, Optional +from typing import Any, Optional -import httpx from httpx import AsyncClient from nonebot.log import logger from typing_extensions import Self -from ..plugin_config import plugin_config from ..post import Post -from ..types import Category, RawPost, Tag, Target +from ..types import ApiError, Category, RawPost, Tag, Target from ..utils import SchedulerConfig -from ..utils.http import http_args from .platform import CategoryNotSupport, NewMessage, StatusChange @@ -105,7 +101,7 @@ class Bilibili(NewMessage): if res_dict["code"] == 0: return res_dict["data"].get("cards") else: - return [] + raise ApiError(res.request.url) def get_id(self, post: RawPost) -> Any: return post["desc"]["dynamic_id"] @@ -306,7 +302,7 @@ class Bilibililive(StatusChange): self.name, text=title, url=url, - pics=pic, + pics=list(pic), target_name=target_name, compress=True, ) @@ -384,14 +380,14 @@ class BilibiliBangumi(StatusChange): lastest_episode = detail_dict["result"]["episodes"] url = lastest_episode["link"] - pic = [lastest_episode["cover"]] + pic: list[str] = [lastest_episode["cover"]] target_name = detail_dict["result"]["season_title"] text = lastest_episode["share_copy"] return Post( self.name, text=text, url=url, - pics=pic, + pics=list(pic), target_name=target_name, compress=True, ) diff --git a/src/plugins/nonebot_bison/platform/ff14.py b/src/plugins/nonebot_bison/platform/ff14.py index a8dfe40..f549046 100644 --- a/src/plugins/nonebot_bison/platform/ff14.py +++ b/src/plugins/nonebot_bison/platform/ff14.py @@ -4,7 +4,7 @@ from httpx import AsyncClient from ..post import Post from ..types import RawPost, Target -from ..utils import http_client, scheduler +from ..utils import scheduler from .platform import NewMessage @@ -27,11 +27,10 @@ class FF14(NewMessage): return "最终幻想XIV官方公告" async def get_sub_list(self, _) -> list[RawPost]: - async with http_client() as client: - raw_data = await client.get( - "https://ff.web.sdo.com/inc/newdata.ashx?url=List?gameCode=ff&category=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5" - ) - return raw_data.json()["Data"] + raw_data = await self.client.get( + "https://ff.web.sdo.com/inc/newdata.ashx?url=List?gameCode=ff&category=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5" + ) + return raw_data.json()["Data"] def get_id(self, post: RawPost) -> Any: """用发布时间当作 ID diff --git a/src/plugins/nonebot_bison/platform/ncm.py b/src/plugins/nonebot_bison/platform/ncm.py index 68e4189..4688f22 100644 --- a/src/plugins/nonebot_bison/platform/ncm.py +++ b/src/plugins/nonebot_bison/platform/ncm.py @@ -4,8 +4,8 @@ from typing import Any, Optional from httpx import AsyncClient from ..post import Post -from ..types import RawPost, Target -from ..utils import SchedulerConfig, http_client +from ..types import ApiError, RawPost, Target +from ..utils import SchedulerConfig from .platform import NewMessage @@ -32,15 +32,14 @@ class NcmArtist(NewMessage): async def get_target_name( cls, client: AsyncClient, target: Target ) -> Optional[str]: - async with http_client() as client: - res = await client.get( - "https://music.163.com/api/artist/albums/{}".format(target), - headers={"Referer": "https://music.163.com/"}, - ) - res_data = res.json() - if res_data["code"] != 200: - return - return res_data["artist"]["name"] + res = await client.get( + "https://music.163.com/api/artist/albums/{}".format(target), + headers={"Referer": "https://music.163.com/"}, + ) + res_data = res.json() + if res_data["code"] != 200: + raise ApiError(res.request.url) + return res_data["artist"]["name"] @classmethod async def parse_target(cls, target_text: str) -> Target: @@ -54,16 +53,15 @@ class NcmArtist(NewMessage): raise cls.ParseTargetException() async def get_sub_list(self, target: Target) -> list[RawPost]: - async with http_client() as client: - res = await client.get( - "https://music.163.com/api/artist/albums/{}".format(target), - headers={"Referer": "https://music.163.com/"}, - ) - res_data = res.json() - if res_data["code"] != 200: - return [] - else: - return res_data["hotAlbums"] + res = await self.client.get( + "https://music.163.com/api/artist/albums/{}".format(target), + headers={"Referer": "https://music.163.com/"}, + ) + res_data = res.json() + if res_data["code"] != 200: + return [] + else: + return res_data["hotAlbums"] def get_id(self, post: RawPost) -> Any: return post["id"] @@ -97,16 +95,15 @@ class NcmRadio(NewMessage): async def get_target_name( cls, client: AsyncClient, target: Target ) -> Optional[str]: - async with http_client() as client: - res = await client.post( - "http://music.163.com/api/dj/program/byradio", - headers={"Referer": "https://music.163.com/"}, - data={"radioId": target, "limit": 1000, "offset": 0}, - ) - res_data = res.json() - if res_data["code"] != 200 or res_data["programs"] == 0: - return - return res_data["programs"][0]["radio"]["name"] + res = await client.post( + "http://music.163.com/api/dj/program/byradio", + headers={"Referer": "https://music.163.com/"}, + data={"radioId": target, "limit": 1000, "offset": 0}, + ) + res_data = res.json() + if res_data["code"] != 200 or res_data["programs"] == 0: + return + return res_data["programs"][0]["radio"]["name"] @classmethod async def parse_target(cls, target_text: str) -> Target: @@ -120,17 +117,16 @@ class NcmRadio(NewMessage): raise cls.ParseTargetException() async def get_sub_list(self, target: Target) -> list[RawPost]: - async with http_client() as client: - res = await client.post( - "http://music.163.com/api/dj/program/byradio", - headers={"Referer": "https://music.163.com/"}, - data={"radioId": target, "limit": 1000, "offset": 0}, - ) - res_data = res.json() - if res_data["code"] != 200: - return [] - else: - return res_data["programs"] + res = await self.client.post( + "http://music.163.com/api/dj/program/byradio", + headers={"Referer": "https://music.163.com/"}, + data={"radioId": target, "limit": 1000, "offset": 0}, + ) + res_data = res.json() + if res_data["code"] != 200: + return [] + else: + return res_data["programs"] def get_id(self, post: RawPost) -> Any: return post["id"] diff --git a/src/plugins/nonebot_bison/platform/platform.py b/src/plugins/nonebot_bison/platform/platform.py index eb2f759..7077685 100644 --- a/src/plugins/nonebot_bison/platform/platform.py +++ b/src/plugins/nonebot_bison/platform/platform.py @@ -14,7 +14,7 @@ from nonebot.log import logger from ..plugin_config import plugin_config from ..post import Post from ..types import Category, RawPost, Tag, Target, User, UserSubInfo -from ..utils.scheduler_config import SchedulerConfig +from ..utils import ProcessContext, SchedulerConfig class CategoryNotSupport(Exception): @@ -57,6 +57,7 @@ class PlatformABCMeta(PlatformMeta, ABC): class Platform(metaclass=PlatformABCMeta, base=True): scheduler: Type[SchedulerConfig] + ctx: ProcessContext is_common: bool enabled: bool name: str @@ -99,7 +100,7 @@ class Platform(metaclass=PlatformABCMeta, base=True): return [] except json.JSONDecodeError as err: logger.warning(f"json error, parsing: {err.doc}") - return [] + raise err @abstractmethod async def parse(self, raw_post: RawPost) -> Post: @@ -109,9 +110,10 @@ class Platform(metaclass=PlatformABCMeta, base=True): "actually function called" return await self.parse(raw_post) - def __init__(self, client: AsyncClient): + def __init__(self, context: ProcessContext, client: AsyncClient): super().__init__() self.client = client + self.ctx = context class ParseTargetException(Exception): pass @@ -209,8 +211,8 @@ class Platform(metaclass=PlatformABCMeta, base=True): class MessageProcess(Platform, abstract=True): "General message process fetch, parse, filter progress" - def __init__(self, client: AsyncClient): - super().__init__(client) + def __init__(self, ctx: ProcessContext, client: AsyncClient): + super().__init__(ctx, client) self.parse_cache: dict[Any, Post] = dict() @abstractmethod @@ -254,6 +256,9 @@ class MessageProcess(Platform, abstract=True): try: self.get_category(raw_post) except CategoryNotSupport: + msgs = self.ctx.gen_req_records() + for m in msgs: + logger.warning(m) continue except NotImplementedError: pass @@ -342,7 +347,7 @@ class StatusChange(Platform, abstract=True): new_status = await self.get_status(target) except self.FetchError as err: logger.warning(f"fetching {self.name}-{target} error: {err}") - return [] + raise res = [] if old_status := self.get_stored_data(target): diff = self.compare_status(target, old_status, new_status) @@ -420,11 +425,11 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]: "Platform scheduler for {} not fit".format(platform_name) ) - def __init__(self: "NoTargetGroup", client: AsyncClient): - Platform.__init__(self, client) + def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient): + Platform.__init__(self, ctx, client) self.platform_obj_list = [] for platform_class in self.platform_list: - self.platform_obj_list.append(platform_class(client)) + self.platform_obj_list.append(platform_class(ctx, client)) def __str__(self: "NoTargetGroup") -> str: return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]" diff --git a/src/plugins/nonebot_bison/platform/rss.py b/src/plugins/nonebot_bison/platform/rss.py index 1ab76f3..eeccd7c 100644 --- a/src/plugins/nonebot_bison/platform/rss.py +++ b/src/plugins/nonebot_bison/platform/rss.py @@ -7,7 +7,7 @@ from httpx import AsyncClient from ..post import Post from ..types import RawPost, Target -from ..utils import http_client, scheduler +from ..utils import scheduler from .platform import NewMessage @@ -26,10 +26,9 @@ class Rss(NewMessage): async def get_target_name( cls, client: AsyncClient, target: Target ) -> Optional[str]: - async with http_client() as client: - res = await client.get(target, timeout=10.0) - feed = feedparser.parse(res.text) - return feed["feed"]["title"] + res = await client.get(target, timeout=10.0) + feed = feedparser.parse(res.text) + return feed["feed"]["title"] def get_date(self, post: RawPost) -> int: return calendar.timegm(post.published_parsed) @@ -38,13 +37,12 @@ class Rss(NewMessage): return post.id async def get_sub_list(self, target: Target) -> list[RawPost]: - async with http_client() as client: - res = await client.get(target, timeout=10.0) - feed = feedparser.parse(res) - entries = feed.entries - for entry in entries: - entry["_target_name"] = feed.feed.title - return feed.entries + res = await self.client.get(target, timeout=10.0) + feed = feedparser.parse(res) + entries = feed.entries + for entry in entries: + entry["_target_name"] = feed.feed.title + return feed.entries async def parse(self, raw_post: RawPost) -> Post: text = raw_post.get("title", "") + "\n" if raw_post.get("title") else "" diff --git a/src/plugins/nonebot_bison/platform/weibo.py b/src/plugins/nonebot_bison/platform/weibo.py index 373e616..ee12ad2 100644 --- a/src/plugins/nonebot_bison/platform/weibo.py +++ b/src/plugins/nonebot_bison/platform/weibo.py @@ -68,7 +68,7 @@ class Weibo(NewMessage): ) res_data = json.loads(res.text) if not res_data["ok"]: - return [] + raise ApiError(res.request.url) custom_filter: Callable[[RawPost], bool] = lambda d: d["card_type"] == 9 return list(filter(custom_filter, res_data["data"]["cards"])) diff --git a/src/plugins/nonebot_bison/scheduler/scheduler.py b/src/plugins/nonebot_bison/scheduler/scheduler.py index 422f453..3232c81 100644 --- a/src/plugins/nonebot_bison/scheduler/scheduler.py +++ b/src/plugins/nonebot_bison/scheduler/scheduler.py @@ -9,7 +9,7 @@ from ..config import config from ..platform import platform_manager from ..send import send_msgs from ..types import Target -from ..utils import SchedulerConfig +from ..utils import ProcessContext, SchedulerConfig from .aps import aps @@ -78,6 +78,7 @@ class Scheduler: return cur_max_schedulable async def exec_fetch(self): + context = ProcessContext() if not (schedulable := await self.get_next_schedulable()): return logger.debug( @@ -86,12 +87,22 @@ class Scheduler: send_userinfo_list = await config.get_platform_target_subscribers( schedulable.platform_name, schedulable.target ) - platform_obj = platform_manager[schedulable.platform_name]( - await self.scheduler_config_obj.get_client(schedulable.target) - ) - to_send = await platform_obj.do_fetch_new_post( - schedulable.target, send_userinfo_list - ) + + client = await self.scheduler_config_obj.get_client(schedulable.target) + context.register_to_client(client) + + try: + platform_obj = platform_manager[schedulable.platform_name](context, client) + to_send = await platform_obj.do_fetch_new_post( + schedulable.target, send_userinfo_list + ) + except Exception as err: + records = context.gen_req_records() + for record in records: + logger.warning("API request record: " + record) + err.args += (records,) + raise + if not to_send: return bot = nonebot.get_bot() diff --git a/src/plugins/nonebot_bison/types.py b/src/plugins/nonebot_bison/types.py index b4df92d..ad9154e 100644 --- a/src/plugins/nonebot_bison/types.py +++ b/src/plugins/nonebot_bison/types.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from datetime import time from typing import Any, Literal, NamedTuple, NewType +from httpx import URL from pydantic import BaseModel RawPost = NewType("RawPost", Any) @@ -45,3 +46,9 @@ class PlatformWeightConfigResp(BaseModel): target_name: str platform_name: str weight: WeightConfig + + +class ApiError(Exception): + def __init__(self, url: URL) -> None: + msg = f"api {url} error" + super().__init__(msg) diff --git a/src/plugins/nonebot_bison/utils/__init__.py b/src/plugins/nonebot_bison/utils/__init__.py index 4a9c2c2..2b04d34 100644 --- a/src/plugins/nonebot_bison/utils/__init__.py +++ b/src/plugins/nonebot_bison/utils/__init__.py @@ -9,6 +9,7 @@ from nonebot.log import default_format, logger from nonebot.plugin import require from ..plugin_config import plugin_config +from .context import ProcessContext from .http import http_client from .scheduler_config import SchedulerConfig, scheduler @@ -16,6 +17,7 @@ __all__ = [ "http_client", "Singleton", "parse_text", + "ProcessContext", "html_to_text", "SchedulerConfig", "scheduler", diff --git a/src/plugins/nonebot_bison/utils/context.py b/src/plugins/nonebot_bison/utils/context.py new file mode 100644 index 0000000..d2eb4cd --- /dev/null +++ b/src/plugins/nonebot_bison/utils/context.py @@ -0,0 +1,40 @@ +from base64 import b64encode + +from httpx import AsyncClient, Response + + +class ProcessContext: + reqs: list[Response] + + def __init__(self) -> None: + self.reqs = [] + + def log_response(self, resp: Response): + self.reqs.append(resp) + + def register_to_client(self, client: AsyncClient): + async def _log_to_ctx(r: Response): + self.log_response(r) + + hooks = { + "response": [_log_to_ctx], + } + client.event_hooks = hooks + + def _should_print_content(self, r: Response) -> bool: + content_type = r.headers["content-type"] + if content_type.startswith("text"): + return True + if "json" in content_type: + return True + return False + + def gen_req_records(self) -> list[str]: + res = [] + for req in self.reqs: + if self._should_print_content(req): + log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} {req.text}" + else: + log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} b64encoded: {b64encode(req.content[:50]).decode()}" + res.append(log_content) + return res diff --git a/tests/platforms/test_arknights.py b/tests/platforms/test_arknights.py index d4eda02..5b19c07 100644 --- a/tests/platforms/test_arknights.py +++ b/tests/platforms/test_arknights.py @@ -9,8 +9,9 @@ from .utils import get_file, get_json @pytest.fixture def arknights(app: App): from nonebot_bison.platform import platform_manager + from nonebot_bison.utils import ProcessContext - return platform_manager["arknights"](AsyncClient()) + return platform_manager["arknights"](ProcessContext(), AsyncClient()) @pytest.fixture(scope="module") diff --git a/tests/platforms/test_bilibili.py b/tests/platforms/test_bilibili.py index c074dc8..b1fd1b6 100644 --- a/tests/platforms/test_bilibili.py +++ b/tests/platforms/test_bilibili.py @@ -22,8 +22,9 @@ if typing.TYPE_CHECKING: @pytest.fixture def bilibili(app: App): from nonebot_bison.platform import platform_manager + from nonebot_bison.utils import ProcessContext - return platform_manager["bilibili"](AsyncClient()) + return platform_manager["bilibili"](ProcessContext(), AsyncClient()) @pytest.mark.asyncio diff --git a/tests/platforms/test_bilibili_bangumi.py b/tests/platforms/test_bilibili_bangumi.py index b5b6f7f..226d98f 100644 --- a/tests/platforms/test_bilibili_bangumi.py +++ b/tests/platforms/test_bilibili_bangumi.py @@ -14,8 +14,9 @@ if typing.TYPE_CHECKING: @pytest.fixture def bili_bangumi(app: App): from nonebot_bison.platform import platform_manager + from nonebot_bison.utils import ProcessContext - return platform_manager["bilibili-bangumi"](AsyncClient()) + return platform_manager["bilibili-bangumi"](ProcessContext(), AsyncClient()) @pytest.mark.asyncio diff --git a/tests/platforms/test_bilibili_live.py b/tests/platforms/test_bilibili_live.py index f5a38c0..2340350 100644 --- a/tests/platforms/test_bilibili_live.py +++ b/tests/platforms/test_bilibili_live.py @@ -9,8 +9,9 @@ from .utils import get_json @pytest.fixture def bili_live(app: App): from nonebot_bison.platform import platform_manager + from nonebot_bison.utils import ProcessContext - return platform_manager["bilibili-live"](AsyncClient()) + return platform_manager["bilibili-live"](ProcessContext(), AsyncClient()) @pytest.fixture diff --git a/tests/platforms/test_ff14.py b/tests/platforms/test_ff14.py index ed6ac2e..5798071 100644 --- a/tests/platforms/test_ff14.py +++ b/tests/platforms/test_ff14.py @@ -9,8 +9,9 @@ from .utils import get_json @pytest.fixture def ff14(app: App): from nonebot_bison.platform import platform_manager + from nonebot_bison.utils import ProcessContext - return platform_manager["ff14"](AsyncClient()) + return platform_manager["ff14"](ProcessContext(), AsyncClient()) @pytest.fixture(scope="module") diff --git a/tests/platforms/test_mcbbsnews.py b/tests/platforms/test_mcbbsnews.py index e551c10..1b594f3 100644 --- a/tests/platforms/test_mcbbsnews.py +++ b/tests/platforms/test_mcbbsnews.py @@ -9,8 +9,9 @@ from .utils import get_file, get_json @pytest.fixture def mcbbsnews(app: App): from nonebot_bison.platform import platform_manager + from nonebot_bison.utils import ProcessContext - return platform_manager["mcbbsnews"](AsyncClient()) + return platform_manager["mcbbsnews"](ProcessContext(), AsyncClient()) @pytest.fixture(scope="module") diff --git a/tests/platforms/test_ncm_artist.py b/tests/platforms/test_ncm_artist.py index 78296de..ce89a77 100644 --- a/tests/platforms/test_ncm_artist.py +++ b/tests/platforms/test_ncm_artist.py @@ -15,8 +15,9 @@ if typing.TYPE_CHECKING: @pytest.fixture def ncm_artist(app: App): from nonebot_bison.platform import platform_manager + from nonebot_bison.utils import ProcessContext - return platform_manager["ncm-artist"](AsyncClient()) + return platform_manager["ncm-artist"](ProcessContext(), AsyncClient()) @pytest.fixture(scope="module") diff --git a/tests/platforms/test_ncm_radio.py b/tests/platforms/test_ncm_radio.py index 37b2160..034032b 100644 --- a/tests/platforms/test_ncm_radio.py +++ b/tests/platforms/test_ncm_radio.py @@ -15,8 +15,9 @@ if typing.TYPE_CHECKING: @pytest.fixture def ncm_radio(app: App): from nonebot_bison.platform import platform_manager + from nonebot_bison.utils import ProcessContext - return platform_manager["ncm-radio"](AsyncClient()) + return platform_manager["ncm-radio"](ProcessContext(), AsyncClient()) @pytest.fixture(scope="module") diff --git a/tests/platforms/test_platform.py b/tests/platforms/test_platform.py index 27f37c1..2e03d3d 100644 --- a/tests/platforms/test_platform.py +++ b/tests/platforms/test_platform.py @@ -1,10 +1,13 @@ from time import time -from typing import Any, Optional +from typing import TYPE_CHECKING, Any import pytest from httpx import AsyncClient from nonebug.app import App +if TYPE_CHECKING: + from nonebot_bison.platform import Platform + now = time() passed = now - 3 * 60 * 60 @@ -56,9 +59,6 @@ def mock_platform_without_cats_tags(app: App): sub_index = 0 - def __init__(self, client): - super().__init__(client) - @classmethod async def get_target_name(cls, client, _: "Target"): return "MockPlatform" @@ -117,9 +117,6 @@ def mock_platform(app: App): sub_index = 0 - def __init__(self, client): - super().__init__(client) - @staticmethod async def get_target_name(_: "Target"): return "MockPlatform" @@ -187,9 +184,6 @@ def mock_platform_no_target(app: App, mock_scheduler_conf): sub_index = 0 - def __init__(self, client): - super().__init__(client) - @staticmethod async def get_target_name(_: "Target"): return "MockPlatform" @@ -250,9 +244,6 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf): sub_index = 0 - def __init__(self, client): - super().__init__(client) - @classmethod async def get_target_name(cls, client, _: "Target"): return "MockPlatform" @@ -319,9 +310,6 @@ def mock_status_change(app: App): sub_index = 0 - def __init__(self, client): - super().__init__(client) - @classmethod async def get_status(cls, _: "Target"): if cls.sub_index == 0: @@ -353,11 +341,15 @@ def mock_status_change(app: App): async def test_new_message_target_without_cats_tags( mock_platform_without_cats_tags, user_info_factory ): - res1 = await mock_platform_without_cats_tags(AsyncClient()).fetch_new_post( - "dummy", [user_info_factory([1, 2], [])] - ) + from nonebot_bison.utils import ProcessContext + + res1 = await mock_platform_without_cats_tags( + ProcessContext(), AsyncClient() + ).fetch_new_post("dummy", [user_info_factory([1, 2], [])]) assert len(res1) == 0 - res2 = await mock_platform_without_cats_tags(AsyncClient()).fetch_new_post( + res2 = await mock_platform_without_cats_tags( + ProcessContext(), AsyncClient() + ).fetch_new_post( "dummy", [ user_info_factory([], []), @@ -372,11 +364,13 @@ async def test_new_message_target_without_cats_tags( @pytest.mark.asyncio async def test_new_message_target(mock_platform, user_info_factory): - res1 = await mock_platform(AsyncClient()).fetch_new_post( + from nonebot_bison.utils import ProcessContext + + res1 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post( "dummy", [user_info_factory([1, 2], [])] ) assert len(res1) == 0 - res2 = await mock_platform(AsyncClient()).fetch_new_post( + res2 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post( "dummy", [ user_info_factory([1, 2], []), @@ -401,11 +395,15 @@ async def test_new_message_target(mock_platform, user_info_factory): @pytest.mark.asyncio async def test_new_message_no_target(mock_platform_no_target, user_info_factory): - res1 = await mock_platform_no_target(AsyncClient()).fetch_new_post( - "dummy", [user_info_factory([1, 2], [])] - ) + from nonebot_bison.utils import ProcessContext + + res1 = await mock_platform_no_target( + ProcessContext(), AsyncClient() + ).fetch_new_post("dummy", [user_info_factory([1, 2], [])]) assert len(res1) == 0 - res2 = await mock_platform_no_target(AsyncClient()).fetch_new_post( + res2 = await mock_platform_no_target( + ProcessContext(), AsyncClient() + ).fetch_new_post( "dummy", [ user_info_factory([1, 2], []), @@ -426,26 +424,28 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory) assert "p2" in id_set_1 and "p3" in id_set_1 assert "p2" in id_set_2 assert "p2" in id_set_3 - res3 = await mock_platform_no_target(AsyncClient()).fetch_new_post( - "dummy", [user_info_factory([1, 2], [])] - ) + res3 = await mock_platform_no_target( + ProcessContext(), AsyncClient() + ).fetch_new_post("dummy", [user_info_factory([1, 2], [])]) assert len(res3) == 0 @pytest.mark.asyncio async def test_status_change(mock_status_change, user_info_factory): - res1 = await mock_status_change(AsyncClient()).fetch_new_post( + from nonebot_bison.utils import ProcessContext + + res1 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( "dummy", [user_info_factory([1, 2], [])] ) assert len(res1) == 0 - res2 = await mock_status_change(AsyncClient()).fetch_new_post( + res2 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( "dummy", [user_info_factory([1, 2], [])] ) assert len(res2) == 1 posts = res2[0][1] assert len(posts) == 1 assert posts[0].text == "on" - res3 = await mock_status_change(AsyncClient()).fetch_new_post( + res3 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( "dummy", [ user_info_factory([1, 2], []), @@ -456,7 +456,7 @@ async def test_status_change(mock_status_change, user_info_factory): assert len(res3[0][1]) == 1 assert res3[0][1][0].text == "off" assert len(res3[1][1]) == 0 - res4 = await mock_status_change(AsyncClient()).fetch_new_post( + res4 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post( "dummy", [user_info_factory([1, 2], [])] ) assert len(res4) == 0 @@ -473,11 +473,12 @@ async def test_group( from nonebot_bison.platform.platform import make_no_target_group from nonebot_bison.post import Post from nonebot_bison.types import Category, RawPost, Tag, Target + from nonebot_bison.utils import ProcessContext group_platform_class = make_no_target_group( [mock_platform_no_target, mock_platform_no_target_2] ) - group_platform = group_platform_class(None) + group_platform = group_platform_class(ProcessContext(), None) res1 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])]) assert len(res1) == 0 res2 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])]) diff --git a/tests/platforms/test_platform_tag_filter.py b/tests/platforms/test_platform_tag_filter.py index 6e340b7..255dddd 100644 --- a/tests/platforms/test_platform_tag_filter.py +++ b/tests/platforms/test_platform_tag_filter.py @@ -14,8 +14,9 @@ def test_cases(): @pytest.mark.asyncio async def test_filter_user_custom_tag(app: App, test_cases): from nonebot_bison.platform import platform_manager + from nonebot_bison.utils import ProcessContext - bilibili = platform_manager["bilibili"](AsyncClient()) + bilibili = platform_manager["bilibili"](ProcessContext(), AsyncClient()) for case in test_cases: res = bilibili.is_banned_post(**case["case"]) assert res == case["result"] @@ -25,8 +26,9 @@ async def test_filter_user_custom_tag(app: App, test_cases): @pytest.mark.asyncio async def test_tag_separator(app: App): from nonebot_bison.platform import platform_manager + from nonebot_bison.utils import ProcessContext - bilibili = platform_manager["bilibili"](AsyncClient()) + bilibili = platform_manager["bilibili"](ProcessContext(), AsyncClient()) tags = ["~111", "222", "333", "~444", "555"] res = bilibili.tag_separator(tags) assert res[0] == ["222", "333", "555"] diff --git a/tests/platforms/test_weibo.py b/tests/platforms/test_weibo.py index 0a9b29c..c6bbd5b 100644 --- a/tests/platforms/test_weibo.py +++ b/tests/platforms/test_weibo.py @@ -21,8 +21,9 @@ image_cdn_router = respx.route( @pytest.fixture def weibo(app: App): from nonebot_bison.platform import platform_manager + from nonebot_bison.utils import ProcessContext - return platform_manager["weibo"](AsyncClient()) + return platform_manager["weibo"](ProcessContext(), AsyncClient()) @pytest.fixture(scope="module") diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 0000000..d98eff6 --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,20 @@ +import httpx +import respx +from nonebug.app import App + + +@respx.mock +async def test_http_error(app: App): + from nonebot_bison.utils import ProcessContext, http_client + + example_route = respx.get("https://example.com") + example_route.mock(httpx.Response(403, json={"error": "gg"})) + + ctx = ProcessContext() + async with http_client() as client: + ctx.register_to_client(client) + await client.get("https://example.com") + + assert ctx.gen_req_records() == [ + "https://example.com Headers({'host': 'example.com', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36'}) | [403] Headers({'content-length': '15', 'content-type': 'application/json'}) {\"error\": \"gg\"}" + ]