add context to log http error

This commit is contained in:
felinae98 2022-11-24 13:12:56 +08:00
parent aa810cc903
commit bd679914eb
22 changed files with 218 additions and 132 deletions

View File

@ -1,21 +1,17 @@
import functools
import json import json
import re import re
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Callable, Optional from typing import Any, Optional
import httpx
from httpx import AsyncClient from httpx import AsyncClient
from nonebot.log import logger from nonebot.log import logger
from typing_extensions import Self from typing_extensions import Self
from ..plugin_config import plugin_config
from ..post import Post from ..post import Post
from ..types import Category, RawPost, Tag, Target from ..types import ApiError, Category, RawPost, Tag, Target
from ..utils import SchedulerConfig from ..utils import SchedulerConfig
from ..utils.http import http_args
from .platform import CategoryNotSupport, NewMessage, StatusChange from .platform import CategoryNotSupport, NewMessage, StatusChange
@ -105,7 +101,7 @@ class Bilibili(NewMessage):
if res_dict["code"] == 0: if res_dict["code"] == 0:
return res_dict["data"].get("cards") return res_dict["data"].get("cards")
else: else:
return [] raise ApiError(res.request.url)
def get_id(self, post: RawPost) -> Any: def get_id(self, post: RawPost) -> Any:
return post["desc"]["dynamic_id"] return post["desc"]["dynamic_id"]
@ -306,7 +302,7 @@ class Bilibililive(StatusChange):
self.name, self.name,
text=title, text=title,
url=url, url=url,
pics=pic, pics=list(pic),
target_name=target_name, target_name=target_name,
compress=True, compress=True,
) )
@ -384,14 +380,14 @@ class BilibiliBangumi(StatusChange):
lastest_episode = detail_dict["result"]["episodes"] lastest_episode = detail_dict["result"]["episodes"]
url = lastest_episode["link"] url = lastest_episode["link"]
pic = [lastest_episode["cover"]] pic: list[str] = [lastest_episode["cover"]]
target_name = detail_dict["result"]["season_title"] target_name = detail_dict["result"]["season_title"]
text = lastest_episode["share_copy"] text = lastest_episode["share_copy"]
return Post( return Post(
self.name, self.name,
text=text, text=text,
url=url, url=url,
pics=pic, pics=list(pic),
target_name=target_name, target_name=target_name,
compress=True, compress=True,
) )

View File

@ -4,7 +4,7 @@ from httpx import AsyncClient
from ..post import Post from ..post import Post
from ..types import RawPost, Target from ..types import RawPost, Target
from ..utils import http_client, scheduler from ..utils import scheduler
from .platform import NewMessage from .platform import NewMessage
@ -27,11 +27,10 @@ class FF14(NewMessage):
return "最终幻想XIV官方公告" return "最终幻想XIV官方公告"
async def get_sub_list(self, _) -> list[RawPost]: async def get_sub_list(self, _) -> list[RawPost]:
async with http_client() as client: raw_data = await self.client.get(
raw_data = await client.get( "https://ff.web.sdo.com/inc/newdata.ashx?url=List?gameCode=ff&category=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5"
"https://ff.web.sdo.com/inc/newdata.ashx?url=List?gameCode=ff&category=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5" )
) return raw_data.json()["Data"]
return raw_data.json()["Data"]
def get_id(self, post: RawPost) -> Any: def get_id(self, post: RawPost) -> Any:
"""用发布时间当作 ID """用发布时间当作 ID

View File

@ -4,8 +4,8 @@ from typing import Any, Optional
from httpx import AsyncClient from httpx import AsyncClient
from ..post import Post from ..post import Post
from ..types import RawPost, Target from ..types import ApiError, RawPost, Target
from ..utils import SchedulerConfig, http_client from ..utils import SchedulerConfig
from .platform import NewMessage from .platform import NewMessage
@ -32,15 +32,14 @@ class NcmArtist(NewMessage):
async def get_target_name( async def get_target_name(
cls, client: AsyncClient, target: Target cls, client: AsyncClient, target: Target
) -> Optional[str]: ) -> Optional[str]:
async with http_client() as client: res = await client.get(
res = await client.get( "https://music.163.com/api/artist/albums/{}".format(target),
"https://music.163.com/api/artist/albums/{}".format(target), headers={"Referer": "https://music.163.com/"},
headers={"Referer": "https://music.163.com/"}, )
) res_data = res.json()
res_data = res.json() if res_data["code"] != 200:
if res_data["code"] != 200: raise ApiError(res.request.url)
return return res_data["artist"]["name"]
return res_data["artist"]["name"]
@classmethod @classmethod
async def parse_target(cls, target_text: str) -> Target: async def parse_target(cls, target_text: str) -> Target:
@ -54,16 +53,15 @@ class NcmArtist(NewMessage):
raise cls.ParseTargetException() raise cls.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]: async def get_sub_list(self, target: Target) -> list[RawPost]:
async with http_client() as client: res = await self.client.get(
res = await client.get( "https://music.163.com/api/artist/albums/{}".format(target),
"https://music.163.com/api/artist/albums/{}".format(target), headers={"Referer": "https://music.163.com/"},
headers={"Referer": "https://music.163.com/"}, )
) res_data = res.json()
res_data = res.json() if res_data["code"] != 200:
if res_data["code"] != 200: return []
return [] else:
else: return res_data["hotAlbums"]
return res_data["hotAlbums"]
def get_id(self, post: RawPost) -> Any: def get_id(self, post: RawPost) -> Any:
return post["id"] return post["id"]
@ -97,16 +95,15 @@ class NcmRadio(NewMessage):
async def get_target_name( async def get_target_name(
cls, client: AsyncClient, target: Target cls, client: AsyncClient, target: Target
) -> Optional[str]: ) -> Optional[str]:
async with http_client() as client: res = await client.post(
res = await client.post( "http://music.163.com/api/dj/program/byradio",
"http://music.163.com/api/dj/program/byradio", headers={"Referer": "https://music.163.com/"},
headers={"Referer": "https://music.163.com/"}, data={"radioId": target, "limit": 1000, "offset": 0},
data={"radioId": target, "limit": 1000, "offset": 0}, )
) res_data = res.json()
res_data = res.json() if res_data["code"] != 200 or res_data["programs"] == 0:
if res_data["code"] != 200 or res_data["programs"] == 0: return
return return res_data["programs"][0]["radio"]["name"]
return res_data["programs"][0]["radio"]["name"]
@classmethod @classmethod
async def parse_target(cls, target_text: str) -> Target: async def parse_target(cls, target_text: str) -> Target:
@ -120,17 +117,16 @@ class NcmRadio(NewMessage):
raise cls.ParseTargetException() raise cls.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]: async def get_sub_list(self, target: Target) -> list[RawPost]:
async with http_client() as client: res = await self.client.post(
res = await client.post( "http://music.163.com/api/dj/program/byradio",
"http://music.163.com/api/dj/program/byradio", headers={"Referer": "https://music.163.com/"},
headers={"Referer": "https://music.163.com/"}, data={"radioId": target, "limit": 1000, "offset": 0},
data={"radioId": target, "limit": 1000, "offset": 0}, )
) res_data = res.json()
res_data = res.json() if res_data["code"] != 200:
if res_data["code"] != 200: return []
return [] else:
else: return res_data["programs"]
return res_data["programs"]
def get_id(self, post: RawPost) -> Any: def get_id(self, post: RawPost) -> Any:
return post["id"] return post["id"]

View File

@ -14,7 +14,7 @@ from nonebot.log import logger
from ..plugin_config import plugin_config from ..plugin_config import plugin_config
from ..post import Post from ..post import Post
from ..types import Category, RawPost, Tag, Target, User, UserSubInfo from ..types import Category, RawPost, Tag, Target, User, UserSubInfo
from ..utils.scheduler_config import SchedulerConfig from ..utils import ProcessContext, SchedulerConfig
class CategoryNotSupport(Exception): class CategoryNotSupport(Exception):
@ -57,6 +57,7 @@ class PlatformABCMeta(PlatformMeta, ABC):
class Platform(metaclass=PlatformABCMeta, base=True): class Platform(metaclass=PlatformABCMeta, base=True):
scheduler: Type[SchedulerConfig] scheduler: Type[SchedulerConfig]
ctx: ProcessContext
is_common: bool is_common: bool
enabled: bool enabled: bool
name: str name: str
@ -99,7 +100,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
return [] return []
except json.JSONDecodeError as err: except json.JSONDecodeError as err:
logger.warning(f"json error, parsing: {err.doc}") logger.warning(f"json error, parsing: {err.doc}")
return [] raise err
@abstractmethod @abstractmethod
async def parse(self, raw_post: RawPost) -> Post: async def parse(self, raw_post: RawPost) -> Post:
@ -109,9 +110,10 @@ class Platform(metaclass=PlatformABCMeta, base=True):
"actually function called" "actually function called"
return await self.parse(raw_post) return await self.parse(raw_post)
def __init__(self, client: AsyncClient): def __init__(self, context: ProcessContext, client: AsyncClient):
super().__init__() super().__init__()
self.client = client self.client = client
self.ctx = context
class ParseTargetException(Exception): class ParseTargetException(Exception):
pass pass
@ -209,8 +211,8 @@ class Platform(metaclass=PlatformABCMeta, base=True):
class MessageProcess(Platform, abstract=True): class MessageProcess(Platform, abstract=True):
"General message process fetch, parse, filter progress" "General message process fetch, parse, filter progress"
def __init__(self, client: AsyncClient): def __init__(self, ctx: ProcessContext, client: AsyncClient):
super().__init__(client) super().__init__(ctx, client)
self.parse_cache: dict[Any, Post] = dict() self.parse_cache: dict[Any, Post] = dict()
@abstractmethod @abstractmethod
@ -254,6 +256,9 @@ class MessageProcess(Platform, abstract=True):
try: try:
self.get_category(raw_post) self.get_category(raw_post)
except CategoryNotSupport: except CategoryNotSupport:
msgs = self.ctx.gen_req_records()
for m in msgs:
logger.warning(m)
continue continue
except NotImplementedError: except NotImplementedError:
pass pass
@ -342,7 +347,7 @@ class StatusChange(Platform, abstract=True):
new_status = await self.get_status(target) new_status = await self.get_status(target)
except self.FetchError as err: except self.FetchError as err:
logger.warning(f"fetching {self.name}-{target} error: {err}") logger.warning(f"fetching {self.name}-{target} error: {err}")
return [] raise
res = [] res = []
if old_status := self.get_stored_data(target): if old_status := self.get_stored_data(target):
diff = self.compare_status(target, old_status, new_status) diff = self.compare_status(target, old_status, new_status)
@ -420,11 +425,11 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
"Platform scheduler for {} not fit".format(platform_name) "Platform scheduler for {} not fit".format(platform_name)
) )
def __init__(self: "NoTargetGroup", client: AsyncClient): def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient):
Platform.__init__(self, client) Platform.__init__(self, ctx, client)
self.platform_obj_list = [] self.platform_obj_list = []
for platform_class in self.platform_list: for platform_class in self.platform_list:
self.platform_obj_list.append(platform_class(client)) self.platform_obj_list.append(platform_class(ctx, client))
def __str__(self: "NoTargetGroup") -> str: def __str__(self: "NoTargetGroup") -> str:
return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]" return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]"

View File

@ -7,7 +7,7 @@ from httpx import AsyncClient
from ..post import Post from ..post import Post
from ..types import RawPost, Target from ..types import RawPost, Target
from ..utils import http_client, scheduler from ..utils import scheduler
from .platform import NewMessage from .platform import NewMessage
@ -26,10 +26,9 @@ class Rss(NewMessage):
async def get_target_name( async def get_target_name(
cls, client: AsyncClient, target: Target cls, client: AsyncClient, target: Target
) -> Optional[str]: ) -> Optional[str]:
async with http_client() as client: res = await client.get(target, timeout=10.0)
res = await client.get(target, timeout=10.0) feed = feedparser.parse(res.text)
feed = feedparser.parse(res.text) return feed["feed"]["title"]
return feed["feed"]["title"]
def get_date(self, post: RawPost) -> int: def get_date(self, post: RawPost) -> int:
return calendar.timegm(post.published_parsed) return calendar.timegm(post.published_parsed)
@ -38,13 +37,12 @@ class Rss(NewMessage):
return post.id return post.id
async def get_sub_list(self, target: Target) -> list[RawPost]: async def get_sub_list(self, target: Target) -> list[RawPost]:
async with http_client() as client: res = await self.client.get(target, timeout=10.0)
res = await client.get(target, timeout=10.0) feed = feedparser.parse(res)
feed = feedparser.parse(res) entries = feed.entries
entries = feed.entries for entry in entries:
for entry in entries: entry["_target_name"] = feed.feed.title
entry["_target_name"] = feed.feed.title return feed.entries
return feed.entries
async def parse(self, raw_post: RawPost) -> Post: async def parse(self, raw_post: RawPost) -> Post:
text = raw_post.get("title", "") + "\n" if raw_post.get("title") else "" text = raw_post.get("title", "") + "\n" if raw_post.get("title") else ""

View File

@ -68,7 +68,7 @@ class Weibo(NewMessage):
) )
res_data = json.loads(res.text) res_data = json.loads(res.text)
if not res_data["ok"]: if not res_data["ok"]:
return [] raise ApiError(res.request.url)
custom_filter: Callable[[RawPost], bool] = lambda d: d["card_type"] == 9 custom_filter: Callable[[RawPost], bool] = lambda d: d["card_type"] == 9
return list(filter(custom_filter, res_data["data"]["cards"])) return list(filter(custom_filter, res_data["data"]["cards"]))

View File

@ -9,7 +9,7 @@ from ..config import config
from ..platform import platform_manager from ..platform import platform_manager
from ..send import send_msgs from ..send import send_msgs
from ..types import Target from ..types import Target
from ..utils import SchedulerConfig from ..utils import ProcessContext, SchedulerConfig
from .aps import aps from .aps import aps
@ -78,6 +78,7 @@ class Scheduler:
return cur_max_schedulable return cur_max_schedulable
async def exec_fetch(self): async def exec_fetch(self):
context = ProcessContext()
if not (schedulable := await self.get_next_schedulable()): if not (schedulable := await self.get_next_schedulable()):
return return
logger.debug( logger.debug(
@ -86,12 +87,22 @@ class Scheduler:
send_userinfo_list = await config.get_platform_target_subscribers( send_userinfo_list = await config.get_platform_target_subscribers(
schedulable.platform_name, schedulable.target schedulable.platform_name, schedulable.target
) )
platform_obj = platform_manager[schedulable.platform_name](
await self.scheduler_config_obj.get_client(schedulable.target) client = await self.scheduler_config_obj.get_client(schedulable.target)
) context.register_to_client(client)
to_send = await platform_obj.do_fetch_new_post(
schedulable.target, send_userinfo_list try:
) platform_obj = platform_manager[schedulable.platform_name](context, client)
to_send = await platform_obj.do_fetch_new_post(
schedulable.target, send_userinfo_list
)
except Exception as err:
records = context.gen_req_records()
for record in records:
logger.warning("API request record: " + record)
err.args += (records,)
raise
if not to_send: if not to_send:
return return
bot = nonebot.get_bot() bot = nonebot.get_bot()

View File

@ -2,6 +2,7 @@ from dataclasses import dataclass
from datetime import time from datetime import time
from typing import Any, Literal, NamedTuple, NewType from typing import Any, Literal, NamedTuple, NewType
from httpx import URL
from pydantic import BaseModel from pydantic import BaseModel
RawPost = NewType("RawPost", Any) RawPost = NewType("RawPost", Any)
@ -45,3 +46,9 @@ class PlatformWeightConfigResp(BaseModel):
target_name: str target_name: str
platform_name: str platform_name: str
weight: WeightConfig weight: WeightConfig
class ApiError(Exception):
def __init__(self, url: URL) -> None:
msg = f"api {url} error"
super().__init__(msg)

View File

@ -9,6 +9,7 @@ from nonebot.log import default_format, logger
from nonebot.plugin import require from nonebot.plugin import require
from ..plugin_config import plugin_config from ..plugin_config import plugin_config
from .context import ProcessContext
from .http import http_client from .http import http_client
from .scheduler_config import SchedulerConfig, scheduler from .scheduler_config import SchedulerConfig, scheduler
@ -16,6 +17,7 @@ __all__ = [
"http_client", "http_client",
"Singleton", "Singleton",
"parse_text", "parse_text",
"ProcessContext",
"html_to_text", "html_to_text",
"SchedulerConfig", "SchedulerConfig",
"scheduler", "scheduler",

View File

@ -0,0 +1,40 @@
from base64 import b64encode
from httpx import AsyncClient, Response
class ProcessContext:
reqs: list[Response]
def __init__(self) -> None:
self.reqs = []
def log_response(self, resp: Response):
self.reqs.append(resp)
def register_to_client(self, client: AsyncClient):
async def _log_to_ctx(r: Response):
self.log_response(r)
hooks = {
"response": [_log_to_ctx],
}
client.event_hooks = hooks
def _should_print_content(self, r: Response) -> bool:
content_type = r.headers["content-type"]
if content_type.startswith("text"):
return True
if "json" in content_type:
return True
return False
def gen_req_records(self) -> list[str]:
res = []
for req in self.reqs:
if self._should_print_content(req):
log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} {req.text}"
else:
log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} b64encoded: {b64encode(req.content[:50]).decode()}"
res.append(log_content)
return res

View File

@ -9,8 +9,9 @@ from .utils import get_file, get_json
@pytest.fixture @pytest.fixture
def arknights(app: App): def arknights(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils import ProcessContext
return platform_manager["arknights"](AsyncClient()) return platform_manager["arknights"](ProcessContext(), AsyncClient())
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -22,8 +22,9 @@ if typing.TYPE_CHECKING:
@pytest.fixture @pytest.fixture
def bilibili(app: App): def bilibili(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils import ProcessContext
return platform_manager["bilibili"](AsyncClient()) return platform_manager["bilibili"](ProcessContext(), AsyncClient())
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -14,8 +14,9 @@ if typing.TYPE_CHECKING:
@pytest.fixture @pytest.fixture
def bili_bangumi(app: App): def bili_bangumi(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils import ProcessContext
return platform_manager["bilibili-bangumi"](AsyncClient()) return platform_manager["bilibili-bangumi"](ProcessContext(), AsyncClient())
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -9,8 +9,9 @@ from .utils import get_json
@pytest.fixture @pytest.fixture
def bili_live(app: App): def bili_live(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils import ProcessContext
return platform_manager["bilibili-live"](AsyncClient()) return platform_manager["bilibili-live"](ProcessContext(), AsyncClient())
@pytest.fixture @pytest.fixture

View File

@ -9,8 +9,9 @@ from .utils import get_json
@pytest.fixture @pytest.fixture
def ff14(app: App): def ff14(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils import ProcessContext
return platform_manager["ff14"](AsyncClient()) return platform_manager["ff14"](ProcessContext(), AsyncClient())
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -9,8 +9,9 @@ from .utils import get_file, get_json
@pytest.fixture @pytest.fixture
def mcbbsnews(app: App): def mcbbsnews(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils import ProcessContext
return platform_manager["mcbbsnews"](AsyncClient()) return platform_manager["mcbbsnews"](ProcessContext(), AsyncClient())
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -15,8 +15,9 @@ if typing.TYPE_CHECKING:
@pytest.fixture @pytest.fixture
def ncm_artist(app: App): def ncm_artist(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils import ProcessContext
return platform_manager["ncm-artist"](AsyncClient()) return platform_manager["ncm-artist"](ProcessContext(), AsyncClient())
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -15,8 +15,9 @@ if typing.TYPE_CHECKING:
@pytest.fixture @pytest.fixture
def ncm_radio(app: App): def ncm_radio(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils import ProcessContext
return platform_manager["ncm-radio"](AsyncClient()) return platform_manager["ncm-radio"](ProcessContext(), AsyncClient())
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -1,10 +1,13 @@
from time import time from time import time
from typing import Any, Optional from typing import TYPE_CHECKING, Any
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from nonebug.app import App from nonebug.app import App
if TYPE_CHECKING:
from nonebot_bison.platform import Platform
now = time() now = time()
passed = now - 3 * 60 * 60 passed = now - 3 * 60 * 60
@ -56,9 +59,6 @@ def mock_platform_without_cats_tags(app: App):
sub_index = 0 sub_index = 0
def __init__(self, client):
super().__init__(client)
@classmethod @classmethod
async def get_target_name(cls, client, _: "Target"): async def get_target_name(cls, client, _: "Target"):
return "MockPlatform" return "MockPlatform"
@ -117,9 +117,6 @@ def mock_platform(app: App):
sub_index = 0 sub_index = 0
def __init__(self, client):
super().__init__(client)
@staticmethod @staticmethod
async def get_target_name(_: "Target"): async def get_target_name(_: "Target"):
return "MockPlatform" return "MockPlatform"
@ -187,9 +184,6 @@ def mock_platform_no_target(app: App, mock_scheduler_conf):
sub_index = 0 sub_index = 0
def __init__(self, client):
super().__init__(client)
@staticmethod @staticmethod
async def get_target_name(_: "Target"): async def get_target_name(_: "Target"):
return "MockPlatform" return "MockPlatform"
@ -250,9 +244,6 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf):
sub_index = 0 sub_index = 0
def __init__(self, client):
super().__init__(client)
@classmethod @classmethod
async def get_target_name(cls, client, _: "Target"): async def get_target_name(cls, client, _: "Target"):
return "MockPlatform" return "MockPlatform"
@ -319,9 +310,6 @@ def mock_status_change(app: App):
sub_index = 0 sub_index = 0
def __init__(self, client):
super().__init__(client)
@classmethod @classmethod
async def get_status(cls, _: "Target"): async def get_status(cls, _: "Target"):
if cls.sub_index == 0: if cls.sub_index == 0:
@ -353,11 +341,15 @@ def mock_status_change(app: App):
async def test_new_message_target_without_cats_tags( async def test_new_message_target_without_cats_tags(
mock_platform_without_cats_tags, user_info_factory mock_platform_without_cats_tags, user_info_factory
): ):
res1 = await mock_platform_without_cats_tags(AsyncClient()).fetch_new_post( from nonebot_bison.utils import ProcessContext
"dummy", [user_info_factory([1, 2], [])]
) res1 = await mock_platform_without_cats_tags(
ProcessContext(), AsyncClient()
).fetch_new_post("dummy", [user_info_factory([1, 2], [])])
assert len(res1) == 0 assert len(res1) == 0
res2 = await mock_platform_without_cats_tags(AsyncClient()).fetch_new_post( res2 = await mock_platform_without_cats_tags(
ProcessContext(), AsyncClient()
).fetch_new_post(
"dummy", "dummy",
[ [
user_info_factory([], []), user_info_factory([], []),
@ -372,11 +364,13 @@ async def test_new_message_target_without_cats_tags(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_message_target(mock_platform, user_info_factory): async def test_new_message_target(mock_platform, user_info_factory):
res1 = await mock_platform(AsyncClient()).fetch_new_post( from nonebot_bison.utils import ProcessContext
res1 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])] "dummy", [user_info_factory([1, 2], [])]
) )
assert len(res1) == 0 assert len(res1) == 0
res2 = await mock_platform(AsyncClient()).fetch_new_post( res2 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy", "dummy",
[ [
user_info_factory([1, 2], []), user_info_factory([1, 2], []),
@ -401,11 +395,15 @@ async def test_new_message_target(mock_platform, user_info_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_message_no_target(mock_platform_no_target, user_info_factory): async def test_new_message_no_target(mock_platform_no_target, user_info_factory):
res1 = await mock_platform_no_target(AsyncClient()).fetch_new_post( from nonebot_bison.utils import ProcessContext
"dummy", [user_info_factory([1, 2], [])]
) res1 = await mock_platform_no_target(
ProcessContext(), AsyncClient()
).fetch_new_post("dummy", [user_info_factory([1, 2], [])])
assert len(res1) == 0 assert len(res1) == 0
res2 = await mock_platform_no_target(AsyncClient()).fetch_new_post( res2 = await mock_platform_no_target(
ProcessContext(), AsyncClient()
).fetch_new_post(
"dummy", "dummy",
[ [
user_info_factory([1, 2], []), user_info_factory([1, 2], []),
@ -426,26 +424,28 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory)
assert "p2" in id_set_1 and "p3" in id_set_1 assert "p2" in id_set_1 and "p3" in id_set_1
assert "p2" in id_set_2 assert "p2" in id_set_2
assert "p2" in id_set_3 assert "p2" in id_set_3
res3 = await mock_platform_no_target(AsyncClient()).fetch_new_post( res3 = await mock_platform_no_target(
"dummy", [user_info_factory([1, 2], [])] ProcessContext(), AsyncClient()
) ).fetch_new_post("dummy", [user_info_factory([1, 2], [])])
assert len(res3) == 0 assert len(res3) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_status_change(mock_status_change, user_info_factory): async def test_status_change(mock_status_change, user_info_factory):
res1 = await mock_status_change(AsyncClient()).fetch_new_post( from nonebot_bison.utils import ProcessContext
res1 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])] "dummy", [user_info_factory([1, 2], [])]
) )
assert len(res1) == 0 assert len(res1) == 0
res2 = await mock_status_change(AsyncClient()).fetch_new_post( res2 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])] "dummy", [user_info_factory([1, 2], [])]
) )
assert len(res2) == 1 assert len(res2) == 1
posts = res2[0][1] posts = res2[0][1]
assert len(posts) == 1 assert len(posts) == 1
assert posts[0].text == "on" assert posts[0].text == "on"
res3 = await mock_status_change(AsyncClient()).fetch_new_post( res3 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy", "dummy",
[ [
user_info_factory([1, 2], []), user_info_factory([1, 2], []),
@ -456,7 +456,7 @@ async def test_status_change(mock_status_change, user_info_factory):
assert len(res3[0][1]) == 1 assert len(res3[0][1]) == 1
assert res3[0][1][0].text == "off" assert res3[0][1][0].text == "off"
assert len(res3[1][1]) == 0 assert len(res3[1][1]) == 0
res4 = await mock_status_change(AsyncClient()).fetch_new_post( res4 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])] "dummy", [user_info_factory([1, 2], [])]
) )
assert len(res4) == 0 assert len(res4) == 0
@ -473,11 +473,12 @@ async def test_group(
from nonebot_bison.platform.platform import make_no_target_group from nonebot_bison.platform.platform import make_no_target_group
from nonebot_bison.post import Post from nonebot_bison.post import Post
from nonebot_bison.types import Category, RawPost, Tag, Target from nonebot_bison.types import Category, RawPost, Tag, Target
from nonebot_bison.utils import ProcessContext
group_platform_class = make_no_target_group( group_platform_class = make_no_target_group(
[mock_platform_no_target, mock_platform_no_target_2] [mock_platform_no_target, mock_platform_no_target_2]
) )
group_platform = group_platform_class(None) group_platform = group_platform_class(ProcessContext(), None)
res1 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])]) res1 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])])
assert len(res1) == 0 assert len(res1) == 0
res2 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])]) res2 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])])

View File

@ -14,8 +14,9 @@ def test_cases():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_filter_user_custom_tag(app: App, test_cases): async def test_filter_user_custom_tag(app: App, test_cases):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils import ProcessContext
bilibili = platform_manager["bilibili"](AsyncClient()) bilibili = platform_manager["bilibili"](ProcessContext(), AsyncClient())
for case in test_cases: for case in test_cases:
res = bilibili.is_banned_post(**case["case"]) res = bilibili.is_banned_post(**case["case"])
assert res == case["result"] assert res == case["result"]
@ -25,8 +26,9 @@ async def test_filter_user_custom_tag(app: App, test_cases):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tag_separator(app: App): async def test_tag_separator(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils import ProcessContext
bilibili = platform_manager["bilibili"](AsyncClient()) bilibili = platform_manager["bilibili"](ProcessContext(), AsyncClient())
tags = ["~111", "222", "333", "~444", "555"] tags = ["~111", "222", "333", "~444", "555"]
res = bilibili.tag_separator(tags) res = bilibili.tag_separator(tags)
assert res[0] == ["222", "333", "555"] assert res[0] == ["222", "333", "555"]

View File

@ -21,8 +21,9 @@ image_cdn_router = respx.route(
@pytest.fixture @pytest.fixture
def weibo(app: App): def weibo(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.utils import ProcessContext
return platform_manager["weibo"](AsyncClient()) return platform_manager["weibo"](ProcessContext(), AsyncClient())
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

20
tests/test_context.py Normal file
View File

@ -0,0 +1,20 @@
import httpx
import respx
from nonebug.app import App
@respx.mock
async def test_http_error(app: App):
from nonebot_bison.utils import ProcessContext, http_client
example_route = respx.get("https://example.com")
example_route.mock(httpx.Response(403, json={"error": "gg"}))
ctx = ProcessContext()
async with http_client() as client:
ctx.register_to_client(client)
await client.get("https://example.com")
assert ctx.gen_req_records() == [
"https://example.com Headers({'host': 'example.com', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36'}) | [403] Headers({'content-length': '15', 'content-type': 'application/json'}) {\"error\": \"gg\"}"
]