mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-07-16 21:53:01 +08:00
✨ add context to log http error
This commit is contained in:
parent
aa810cc903
commit
bd679914eb
@ -1,21 +1,17 @@
|
||||
import functools
|
||||
import json
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from httpx import AsyncClient
|
||||
from nonebot.log import logger
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..plugin_config import plugin_config
|
||||
from ..post import Post
|
||||
from ..types import Category, RawPost, Tag, Target
|
||||
from ..types import ApiError, Category, RawPost, Tag, Target
|
||||
from ..utils import SchedulerConfig
|
||||
from ..utils.http import http_args
|
||||
from .platform import CategoryNotSupport, NewMessage, StatusChange
|
||||
|
||||
|
||||
@ -105,7 +101,7 @@ class Bilibili(NewMessage):
|
||||
if res_dict["code"] == 0:
|
||||
return res_dict["data"].get("cards")
|
||||
else:
|
||||
return []
|
||||
raise ApiError(res.request.url)
|
||||
|
||||
def get_id(self, post: RawPost) -> Any:
|
||||
return post["desc"]["dynamic_id"]
|
||||
@ -306,7 +302,7 @@ class Bilibililive(StatusChange):
|
||||
self.name,
|
||||
text=title,
|
||||
url=url,
|
||||
pics=pic,
|
||||
pics=list(pic),
|
||||
target_name=target_name,
|
||||
compress=True,
|
||||
)
|
||||
@ -384,14 +380,14 @@ class BilibiliBangumi(StatusChange):
|
||||
lastest_episode = detail_dict["result"]["episodes"]
|
||||
|
||||
url = lastest_episode["link"]
|
||||
pic = [lastest_episode["cover"]]
|
||||
pic: list[str] = [lastest_episode["cover"]]
|
||||
target_name = detail_dict["result"]["season_title"]
|
||||
text = lastest_episode["share_copy"]
|
||||
return Post(
|
||||
self.name,
|
||||
text=text,
|
||||
url=url,
|
||||
pics=pic,
|
||||
pics=list(pic),
|
||||
target_name=target_name,
|
||||
compress=True,
|
||||
)
|
||||
|
@ -4,7 +4,7 @@ from httpx import AsyncClient
|
||||
|
||||
from ..post import Post
|
||||
from ..types import RawPost, Target
|
||||
from ..utils import http_client, scheduler
|
||||
from ..utils import scheduler
|
||||
from .platform import NewMessage
|
||||
|
||||
|
||||
@ -27,11 +27,10 @@ class FF14(NewMessage):
|
||||
return "最终幻想XIV官方公告"
|
||||
|
||||
async def get_sub_list(self, _) -> list[RawPost]:
|
||||
async with http_client() as client:
|
||||
raw_data = await client.get(
|
||||
"https://ff.web.sdo.com/inc/newdata.ashx?url=List?gameCode=ff&category=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5"
|
||||
)
|
||||
return raw_data.json()["Data"]
|
||||
raw_data = await self.client.get(
|
||||
"https://ff.web.sdo.com/inc/newdata.ashx?url=List?gameCode=ff&category=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5"
|
||||
)
|
||||
return raw_data.json()["Data"]
|
||||
|
||||
def get_id(self, post: RawPost) -> Any:
|
||||
"""用发布时间当作 ID
|
||||
|
@ -4,8 +4,8 @@ from typing import Any, Optional
|
||||
from httpx import AsyncClient
|
||||
|
||||
from ..post import Post
|
||||
from ..types import RawPost, Target
|
||||
from ..utils import SchedulerConfig, http_client
|
||||
from ..types import ApiError, RawPost, Target
|
||||
from ..utils import SchedulerConfig
|
||||
from .platform import NewMessage
|
||||
|
||||
|
||||
@ -32,15 +32,14 @@ class NcmArtist(NewMessage):
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async with http_client() as client:
|
||||
res = await client.get(
|
||||
"https://music.163.com/api/artist/albums/{}".format(target),
|
||||
headers={"Referer": "https://music.163.com/"},
|
||||
)
|
||||
res_data = res.json()
|
||||
if res_data["code"] != 200:
|
||||
return
|
||||
return res_data["artist"]["name"]
|
||||
res = await client.get(
|
||||
"https://music.163.com/api/artist/albums/{}".format(target),
|
||||
headers={"Referer": "https://music.163.com/"},
|
||||
)
|
||||
res_data = res.json()
|
||||
if res_data["code"] != 200:
|
||||
raise ApiError(res.request.url)
|
||||
return res_data["artist"]["name"]
|
||||
|
||||
@classmethod
|
||||
async def parse_target(cls, target_text: str) -> Target:
|
||||
@ -54,16 +53,15 @@ class NcmArtist(NewMessage):
|
||||
raise cls.ParseTargetException()
|
||||
|
||||
async def get_sub_list(self, target: Target) -> list[RawPost]:
|
||||
async with http_client() as client:
|
||||
res = await client.get(
|
||||
"https://music.163.com/api/artist/albums/{}".format(target),
|
||||
headers={"Referer": "https://music.163.com/"},
|
||||
)
|
||||
res_data = res.json()
|
||||
if res_data["code"] != 200:
|
||||
return []
|
||||
else:
|
||||
return res_data["hotAlbums"]
|
||||
res = await self.client.get(
|
||||
"https://music.163.com/api/artist/albums/{}".format(target),
|
||||
headers={"Referer": "https://music.163.com/"},
|
||||
)
|
||||
res_data = res.json()
|
||||
if res_data["code"] != 200:
|
||||
return []
|
||||
else:
|
||||
return res_data["hotAlbums"]
|
||||
|
||||
def get_id(self, post: RawPost) -> Any:
|
||||
return post["id"]
|
||||
@ -97,16 +95,15 @@ class NcmRadio(NewMessage):
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async with http_client() as client:
|
||||
res = await client.post(
|
||||
"http://music.163.com/api/dj/program/byradio",
|
||||
headers={"Referer": "https://music.163.com/"},
|
||||
data={"radioId": target, "limit": 1000, "offset": 0},
|
||||
)
|
||||
res_data = res.json()
|
||||
if res_data["code"] != 200 or res_data["programs"] == 0:
|
||||
return
|
||||
return res_data["programs"][0]["radio"]["name"]
|
||||
res = await client.post(
|
||||
"http://music.163.com/api/dj/program/byradio",
|
||||
headers={"Referer": "https://music.163.com/"},
|
||||
data={"radioId": target, "limit": 1000, "offset": 0},
|
||||
)
|
||||
res_data = res.json()
|
||||
if res_data["code"] != 200 or res_data["programs"] == 0:
|
||||
return
|
||||
return res_data["programs"][0]["radio"]["name"]
|
||||
|
||||
@classmethod
|
||||
async def parse_target(cls, target_text: str) -> Target:
|
||||
@ -120,17 +117,16 @@ class NcmRadio(NewMessage):
|
||||
raise cls.ParseTargetException()
|
||||
|
||||
async def get_sub_list(self, target: Target) -> list[RawPost]:
|
||||
async with http_client() as client:
|
||||
res = await client.post(
|
||||
"http://music.163.com/api/dj/program/byradio",
|
||||
headers={"Referer": "https://music.163.com/"},
|
||||
data={"radioId": target, "limit": 1000, "offset": 0},
|
||||
)
|
||||
res_data = res.json()
|
||||
if res_data["code"] != 200:
|
||||
return []
|
||||
else:
|
||||
return res_data["programs"]
|
||||
res = await self.client.post(
|
||||
"http://music.163.com/api/dj/program/byradio",
|
||||
headers={"Referer": "https://music.163.com/"},
|
||||
data={"radioId": target, "limit": 1000, "offset": 0},
|
||||
)
|
||||
res_data = res.json()
|
||||
if res_data["code"] != 200:
|
||||
return []
|
||||
else:
|
||||
return res_data["programs"]
|
||||
|
||||
def get_id(self, post: RawPost) -> Any:
|
||||
return post["id"]
|
||||
|
@ -14,7 +14,7 @@ from nonebot.log import logger
|
||||
from ..plugin_config import plugin_config
|
||||
from ..post import Post
|
||||
from ..types import Category, RawPost, Tag, Target, User, UserSubInfo
|
||||
from ..utils.scheduler_config import SchedulerConfig
|
||||
from ..utils import ProcessContext, SchedulerConfig
|
||||
|
||||
|
||||
class CategoryNotSupport(Exception):
|
||||
@ -57,6 +57,7 @@ class PlatformABCMeta(PlatformMeta, ABC):
|
||||
class Platform(metaclass=PlatformABCMeta, base=True):
|
||||
|
||||
scheduler: Type[SchedulerConfig]
|
||||
ctx: ProcessContext
|
||||
is_common: bool
|
||||
enabled: bool
|
||||
name: str
|
||||
@ -99,7 +100,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
||||
return []
|
||||
except json.JSONDecodeError as err:
|
||||
logger.warning(f"json error, parsing: {err.doc}")
|
||||
return []
|
||||
raise err
|
||||
|
||||
@abstractmethod
|
||||
async def parse(self, raw_post: RawPost) -> Post:
|
||||
@ -109,9 +110,10 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
||||
"actually function called"
|
||||
return await self.parse(raw_post)
|
||||
|
||||
def __init__(self, client: AsyncClient):
|
||||
def __init__(self, context: ProcessContext, client: AsyncClient):
|
||||
super().__init__()
|
||||
self.client = client
|
||||
self.ctx = context
|
||||
|
||||
class ParseTargetException(Exception):
|
||||
pass
|
||||
@ -209,8 +211,8 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
||||
class MessageProcess(Platform, abstract=True):
|
||||
"General message process fetch, parse, filter progress"
|
||||
|
||||
def __init__(self, client: AsyncClient):
|
||||
super().__init__(client)
|
||||
def __init__(self, ctx: ProcessContext, client: AsyncClient):
|
||||
super().__init__(ctx, client)
|
||||
self.parse_cache: dict[Any, Post] = dict()
|
||||
|
||||
@abstractmethod
|
||||
@ -254,6 +256,9 @@ class MessageProcess(Platform, abstract=True):
|
||||
try:
|
||||
self.get_category(raw_post)
|
||||
except CategoryNotSupport:
|
||||
msgs = self.ctx.gen_req_records()
|
||||
for m in msgs:
|
||||
logger.warning(m)
|
||||
continue
|
||||
except NotImplementedError:
|
||||
pass
|
||||
@ -342,7 +347,7 @@ class StatusChange(Platform, abstract=True):
|
||||
new_status = await self.get_status(target)
|
||||
except self.FetchError as err:
|
||||
logger.warning(f"fetching {self.name}-{target} error: {err}")
|
||||
return []
|
||||
raise
|
||||
res = []
|
||||
if old_status := self.get_stored_data(target):
|
||||
diff = self.compare_status(target, old_status, new_status)
|
||||
@ -420,11 +425,11 @@ def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
|
||||
"Platform scheduler for {} not fit".format(platform_name)
|
||||
)
|
||||
|
||||
def __init__(self: "NoTargetGroup", client: AsyncClient):
|
||||
Platform.__init__(self, client)
|
||||
def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient):
|
||||
Platform.__init__(self, ctx, client)
|
||||
self.platform_obj_list = []
|
||||
for platform_class in self.platform_list:
|
||||
self.platform_obj_list.append(platform_class(client))
|
||||
self.platform_obj_list.append(platform_class(ctx, client))
|
||||
|
||||
def __str__(self: "NoTargetGroup") -> str:
|
||||
return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]"
|
||||
|
@ -7,7 +7,7 @@ from httpx import AsyncClient
|
||||
|
||||
from ..post import Post
|
||||
from ..types import RawPost, Target
|
||||
from ..utils import http_client, scheduler
|
||||
from ..utils import scheduler
|
||||
from .platform import NewMessage
|
||||
|
||||
|
||||
@ -26,10 +26,9 @@ class Rss(NewMessage):
|
||||
async def get_target_name(
|
||||
cls, client: AsyncClient, target: Target
|
||||
) -> Optional[str]:
|
||||
async with http_client() as client:
|
||||
res = await client.get(target, timeout=10.0)
|
||||
feed = feedparser.parse(res.text)
|
||||
return feed["feed"]["title"]
|
||||
res = await client.get(target, timeout=10.0)
|
||||
feed = feedparser.parse(res.text)
|
||||
return feed["feed"]["title"]
|
||||
|
||||
def get_date(self, post: RawPost) -> int:
|
||||
return calendar.timegm(post.published_parsed)
|
||||
@ -38,13 +37,12 @@ class Rss(NewMessage):
|
||||
return post.id
|
||||
|
||||
async def get_sub_list(self, target: Target) -> list[RawPost]:
|
||||
async with http_client() as client:
|
||||
res = await client.get(target, timeout=10.0)
|
||||
feed = feedparser.parse(res)
|
||||
entries = feed.entries
|
||||
for entry in entries:
|
||||
entry["_target_name"] = feed.feed.title
|
||||
return feed.entries
|
||||
res = await self.client.get(target, timeout=10.0)
|
||||
feed = feedparser.parse(res)
|
||||
entries = feed.entries
|
||||
for entry in entries:
|
||||
entry["_target_name"] = feed.feed.title
|
||||
return feed.entries
|
||||
|
||||
async def parse(self, raw_post: RawPost) -> Post:
|
||||
text = raw_post.get("title", "") + "\n" if raw_post.get("title") else ""
|
||||
|
@ -68,7 +68,7 @@ class Weibo(NewMessage):
|
||||
)
|
||||
res_data = json.loads(res.text)
|
||||
if not res_data["ok"]:
|
||||
return []
|
||||
raise ApiError(res.request.url)
|
||||
custom_filter: Callable[[RawPost], bool] = lambda d: d["card_type"] == 9
|
||||
return list(filter(custom_filter, res_data["data"]["cards"]))
|
||||
|
||||
|
@ -9,7 +9,7 @@ from ..config import config
|
||||
from ..platform import platform_manager
|
||||
from ..send import send_msgs
|
||||
from ..types import Target
|
||||
from ..utils import SchedulerConfig
|
||||
from ..utils import ProcessContext, SchedulerConfig
|
||||
from .aps import aps
|
||||
|
||||
|
||||
@ -78,6 +78,7 @@ class Scheduler:
|
||||
return cur_max_schedulable
|
||||
|
||||
async def exec_fetch(self):
|
||||
context = ProcessContext()
|
||||
if not (schedulable := await self.get_next_schedulable()):
|
||||
return
|
||||
logger.debug(
|
||||
@ -86,12 +87,22 @@ class Scheduler:
|
||||
send_userinfo_list = await config.get_platform_target_subscribers(
|
||||
schedulable.platform_name, schedulable.target
|
||||
)
|
||||
platform_obj = platform_manager[schedulable.platform_name](
|
||||
await self.scheduler_config_obj.get_client(schedulable.target)
|
||||
)
|
||||
to_send = await platform_obj.do_fetch_new_post(
|
||||
schedulable.target, send_userinfo_list
|
||||
)
|
||||
|
||||
client = await self.scheduler_config_obj.get_client(schedulable.target)
|
||||
context.register_to_client(client)
|
||||
|
||||
try:
|
||||
platform_obj = platform_manager[schedulable.platform_name](context, client)
|
||||
to_send = await platform_obj.do_fetch_new_post(
|
||||
schedulable.target, send_userinfo_list
|
||||
)
|
||||
except Exception as err:
|
||||
records = context.gen_req_records()
|
||||
for record in records:
|
||||
logger.warning("API request record: " + record)
|
||||
err.args += (records,)
|
||||
raise
|
||||
|
||||
if not to_send:
|
||||
return
|
||||
bot = nonebot.get_bot()
|
||||
|
@ -2,6 +2,7 @@ from dataclasses import dataclass
|
||||
from datetime import time
|
||||
from typing import Any, Literal, NamedTuple, NewType
|
||||
|
||||
from httpx import URL
|
||||
from pydantic import BaseModel
|
||||
|
||||
RawPost = NewType("RawPost", Any)
|
||||
@ -45,3 +46,9 @@ class PlatformWeightConfigResp(BaseModel):
|
||||
target_name: str
|
||||
platform_name: str
|
||||
weight: WeightConfig
|
||||
|
||||
|
||||
class ApiError(Exception):
|
||||
def __init__(self, url: URL) -> None:
|
||||
msg = f"api {url} error"
|
||||
super().__init__(msg)
|
||||
|
@ -9,6 +9,7 @@ from nonebot.log import default_format, logger
|
||||
from nonebot.plugin import require
|
||||
|
||||
from ..plugin_config import plugin_config
|
||||
from .context import ProcessContext
|
||||
from .http import http_client
|
||||
from .scheduler_config import SchedulerConfig, scheduler
|
||||
|
||||
@ -16,6 +17,7 @@ __all__ = [
|
||||
"http_client",
|
||||
"Singleton",
|
||||
"parse_text",
|
||||
"ProcessContext",
|
||||
"html_to_text",
|
||||
"SchedulerConfig",
|
||||
"scheduler",
|
||||
|
40
src/plugins/nonebot_bison/utils/context.py
Normal file
40
src/plugins/nonebot_bison/utils/context.py
Normal file
@ -0,0 +1,40 @@
|
||||
from base64 import b64encode
|
||||
|
||||
from httpx import AsyncClient, Response
|
||||
|
||||
|
||||
class ProcessContext:
|
||||
reqs: list[Response]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.reqs = []
|
||||
|
||||
def log_response(self, resp: Response):
|
||||
self.reqs.append(resp)
|
||||
|
||||
def register_to_client(self, client: AsyncClient):
|
||||
async def _log_to_ctx(r: Response):
|
||||
self.log_response(r)
|
||||
|
||||
hooks = {
|
||||
"response": [_log_to_ctx],
|
||||
}
|
||||
client.event_hooks = hooks
|
||||
|
||||
def _should_print_content(self, r: Response) -> bool:
|
||||
content_type = r.headers["content-type"]
|
||||
if content_type.startswith("text"):
|
||||
return True
|
||||
if "json" in content_type:
|
||||
return True
|
||||
return False
|
||||
|
||||
def gen_req_records(self) -> list[str]:
|
||||
res = []
|
||||
for req in self.reqs:
|
||||
if self._should_print_content(req):
|
||||
log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} {req.text}"
|
||||
else:
|
||||
log_content = f"{req.request.url} {req.request.headers} | [{req.status_code}] {req.headers} b64encoded: {b64encode(req.content[:50]).decode()}"
|
||||
res.append(log_content)
|
||||
return res
|
@ -9,8 +9,9 @@ from .utils import get_file, get_json
|
||||
@pytest.fixture
|
||||
def arknights(app: App):
|
||||
from nonebot_bison.platform import platform_manager
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
return platform_manager["arknights"](AsyncClient())
|
||||
return platform_manager["arknights"](ProcessContext(), AsyncClient())
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -22,8 +22,9 @@ if typing.TYPE_CHECKING:
|
||||
@pytest.fixture
|
||||
def bilibili(app: App):
|
||||
from nonebot_bison.platform import platform_manager
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
return platform_manager["bilibili"](AsyncClient())
|
||||
return platform_manager["bilibili"](ProcessContext(), AsyncClient())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -14,8 +14,9 @@ if typing.TYPE_CHECKING:
|
||||
@pytest.fixture
|
||||
def bili_bangumi(app: App):
|
||||
from nonebot_bison.platform import platform_manager
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
return platform_manager["bilibili-bangumi"](AsyncClient())
|
||||
return platform_manager["bilibili-bangumi"](ProcessContext(), AsyncClient())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -9,8 +9,9 @@ from .utils import get_json
|
||||
@pytest.fixture
|
||||
def bili_live(app: App):
|
||||
from nonebot_bison.platform import platform_manager
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
return platform_manager["bilibili-live"](AsyncClient())
|
||||
return platform_manager["bilibili-live"](ProcessContext(), AsyncClient())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -9,8 +9,9 @@ from .utils import get_json
|
||||
@pytest.fixture
|
||||
def ff14(app: App):
|
||||
from nonebot_bison.platform import platform_manager
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
return platform_manager["ff14"](AsyncClient())
|
||||
return platform_manager["ff14"](ProcessContext(), AsyncClient())
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -9,8 +9,9 @@ from .utils import get_file, get_json
|
||||
@pytest.fixture
|
||||
def mcbbsnews(app: App):
|
||||
from nonebot_bison.platform import platform_manager
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
return platform_manager["mcbbsnews"](AsyncClient())
|
||||
return platform_manager["mcbbsnews"](ProcessContext(), AsyncClient())
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -15,8 +15,9 @@ if typing.TYPE_CHECKING:
|
||||
@pytest.fixture
|
||||
def ncm_artist(app: App):
|
||||
from nonebot_bison.platform import platform_manager
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
return platform_manager["ncm-artist"](AsyncClient())
|
||||
return platform_manager["ncm-artist"](ProcessContext(), AsyncClient())
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -15,8 +15,9 @@ if typing.TYPE_CHECKING:
|
||||
@pytest.fixture
|
||||
def ncm_radio(app: App):
|
||||
from nonebot_bison.platform import platform_manager
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
return platform_manager["ncm-radio"](AsyncClient())
|
||||
return platform_manager["ncm-radio"](ProcessContext(), AsyncClient())
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -1,10 +1,13 @@
|
||||
from time import time
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from nonebug.app import App
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot_bison.platform import Platform
|
||||
|
||||
now = time()
|
||||
passed = now - 3 * 60 * 60
|
||||
|
||||
@ -56,9 +59,6 @@ def mock_platform_without_cats_tags(app: App):
|
||||
|
||||
sub_index = 0
|
||||
|
||||
def __init__(self, client):
|
||||
super().__init__(client)
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(cls, client, _: "Target"):
|
||||
return "MockPlatform"
|
||||
@ -117,9 +117,6 @@ def mock_platform(app: App):
|
||||
|
||||
sub_index = 0
|
||||
|
||||
def __init__(self, client):
|
||||
super().__init__(client)
|
||||
|
||||
@staticmethod
|
||||
async def get_target_name(_: "Target"):
|
||||
return "MockPlatform"
|
||||
@ -187,9 +184,6 @@ def mock_platform_no_target(app: App, mock_scheduler_conf):
|
||||
|
||||
sub_index = 0
|
||||
|
||||
def __init__(self, client):
|
||||
super().__init__(client)
|
||||
|
||||
@staticmethod
|
||||
async def get_target_name(_: "Target"):
|
||||
return "MockPlatform"
|
||||
@ -250,9 +244,6 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf):
|
||||
|
||||
sub_index = 0
|
||||
|
||||
def __init__(self, client):
|
||||
super().__init__(client)
|
||||
|
||||
@classmethod
|
||||
async def get_target_name(cls, client, _: "Target"):
|
||||
return "MockPlatform"
|
||||
@ -319,9 +310,6 @@ def mock_status_change(app: App):
|
||||
|
||||
sub_index = 0
|
||||
|
||||
def __init__(self, client):
|
||||
super().__init__(client)
|
||||
|
||||
@classmethod
|
||||
async def get_status(cls, _: "Target"):
|
||||
if cls.sub_index == 0:
|
||||
@ -353,11 +341,15 @@ def mock_status_change(app: App):
|
||||
async def test_new_message_target_without_cats_tags(
|
||||
mock_platform_without_cats_tags, user_info_factory
|
||||
):
|
||||
res1 = await mock_platform_without_cats_tags(AsyncClient()).fetch_new_post(
|
||||
"dummy", [user_info_factory([1, 2], [])]
|
||||
)
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
res1 = await mock_platform_without_cats_tags(
|
||||
ProcessContext(), AsyncClient()
|
||||
).fetch_new_post("dummy", [user_info_factory([1, 2], [])])
|
||||
assert len(res1) == 0
|
||||
res2 = await mock_platform_without_cats_tags(AsyncClient()).fetch_new_post(
|
||||
res2 = await mock_platform_without_cats_tags(
|
||||
ProcessContext(), AsyncClient()
|
||||
).fetch_new_post(
|
||||
"dummy",
|
||||
[
|
||||
user_info_factory([], []),
|
||||
@ -372,11 +364,13 @@ async def test_new_message_target_without_cats_tags(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_message_target(mock_platform, user_info_factory):
|
||||
res1 = await mock_platform(AsyncClient()).fetch_new_post(
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
res1 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post(
|
||||
"dummy", [user_info_factory([1, 2], [])]
|
||||
)
|
||||
assert len(res1) == 0
|
||||
res2 = await mock_platform(AsyncClient()).fetch_new_post(
|
||||
res2 = await mock_platform(ProcessContext(), AsyncClient()).fetch_new_post(
|
||||
"dummy",
|
||||
[
|
||||
user_info_factory([1, 2], []),
|
||||
@ -401,11 +395,15 @@ async def test_new_message_target(mock_platform, user_info_factory):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_message_no_target(mock_platform_no_target, user_info_factory):
|
||||
res1 = await mock_platform_no_target(AsyncClient()).fetch_new_post(
|
||||
"dummy", [user_info_factory([1, 2], [])]
|
||||
)
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
res1 = await mock_platform_no_target(
|
||||
ProcessContext(), AsyncClient()
|
||||
).fetch_new_post("dummy", [user_info_factory([1, 2], [])])
|
||||
assert len(res1) == 0
|
||||
res2 = await mock_platform_no_target(AsyncClient()).fetch_new_post(
|
||||
res2 = await mock_platform_no_target(
|
||||
ProcessContext(), AsyncClient()
|
||||
).fetch_new_post(
|
||||
"dummy",
|
||||
[
|
||||
user_info_factory([1, 2], []),
|
||||
@ -426,26 +424,28 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory)
|
||||
assert "p2" in id_set_1 and "p3" in id_set_1
|
||||
assert "p2" in id_set_2
|
||||
assert "p2" in id_set_3
|
||||
res3 = await mock_platform_no_target(AsyncClient()).fetch_new_post(
|
||||
"dummy", [user_info_factory([1, 2], [])]
|
||||
)
|
||||
res3 = await mock_platform_no_target(
|
||||
ProcessContext(), AsyncClient()
|
||||
).fetch_new_post("dummy", [user_info_factory([1, 2], [])])
|
||||
assert len(res3) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_change(mock_status_change, user_info_factory):
|
||||
res1 = await mock_status_change(AsyncClient()).fetch_new_post(
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
res1 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post(
|
||||
"dummy", [user_info_factory([1, 2], [])]
|
||||
)
|
||||
assert len(res1) == 0
|
||||
res2 = await mock_status_change(AsyncClient()).fetch_new_post(
|
||||
res2 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post(
|
||||
"dummy", [user_info_factory([1, 2], [])]
|
||||
)
|
||||
assert len(res2) == 1
|
||||
posts = res2[0][1]
|
||||
assert len(posts) == 1
|
||||
assert posts[0].text == "on"
|
||||
res3 = await mock_status_change(AsyncClient()).fetch_new_post(
|
||||
res3 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post(
|
||||
"dummy",
|
||||
[
|
||||
user_info_factory([1, 2], []),
|
||||
@ -456,7 +456,7 @@ async def test_status_change(mock_status_change, user_info_factory):
|
||||
assert len(res3[0][1]) == 1
|
||||
assert res3[0][1][0].text == "off"
|
||||
assert len(res3[1][1]) == 0
|
||||
res4 = await mock_status_change(AsyncClient()).fetch_new_post(
|
||||
res4 = await mock_status_change(ProcessContext(), AsyncClient()).fetch_new_post(
|
||||
"dummy", [user_info_factory([1, 2], [])]
|
||||
)
|
||||
assert len(res4) == 0
|
||||
@ -473,11 +473,12 @@ async def test_group(
|
||||
from nonebot_bison.platform.platform import make_no_target_group
|
||||
from nonebot_bison.post import Post
|
||||
from nonebot_bison.types import Category, RawPost, Tag, Target
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
group_platform_class = make_no_target_group(
|
||||
[mock_platform_no_target, mock_platform_no_target_2]
|
||||
)
|
||||
group_platform = group_platform_class(None)
|
||||
group_platform = group_platform_class(ProcessContext(), None)
|
||||
res1 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])])
|
||||
assert len(res1) == 0
|
||||
res2 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])])
|
||||
|
@ -14,8 +14,9 @@ def test_cases():
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_user_custom_tag(app: App, test_cases):
|
||||
from nonebot_bison.platform import platform_manager
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
bilibili = platform_manager["bilibili"](AsyncClient())
|
||||
bilibili = platform_manager["bilibili"](ProcessContext(), AsyncClient())
|
||||
for case in test_cases:
|
||||
res = bilibili.is_banned_post(**case["case"])
|
||||
assert res == case["result"]
|
||||
@ -25,8 +26,9 @@ async def test_filter_user_custom_tag(app: App, test_cases):
|
||||
@pytest.mark.asyncio
|
||||
async def test_tag_separator(app: App):
|
||||
from nonebot_bison.platform import platform_manager
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
bilibili = platform_manager["bilibili"](AsyncClient())
|
||||
bilibili = platform_manager["bilibili"](ProcessContext(), AsyncClient())
|
||||
tags = ["~111", "222", "333", "~444", "555"]
|
||||
res = bilibili.tag_separator(tags)
|
||||
assert res[0] == ["222", "333", "555"]
|
||||
|
@ -21,8 +21,9 @@ image_cdn_router = respx.route(
|
||||
@pytest.fixture
|
||||
def weibo(app: App):
|
||||
from nonebot_bison.platform import platform_manager
|
||||
from nonebot_bison.utils import ProcessContext
|
||||
|
||||
return platform_manager["weibo"](AsyncClient())
|
||||
return platform_manager["weibo"](ProcessContext(), AsyncClient())
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
20
tests/test_context.py
Normal file
20
tests/test_context.py
Normal file
@ -0,0 +1,20 @@
|
||||
import httpx
|
||||
import respx
|
||||
from nonebug.app import App
|
||||
|
||||
|
||||
@respx.mock
|
||||
async def test_http_error(app: App):
|
||||
from nonebot_bison.utils import ProcessContext, http_client
|
||||
|
||||
example_route = respx.get("https://example.com")
|
||||
example_route.mock(httpx.Response(403, json={"error": "gg"}))
|
||||
|
||||
ctx = ProcessContext()
|
||||
async with http_client() as client:
|
||||
ctx.register_to_client(client)
|
||||
await client.get("https://example.com")
|
||||
|
||||
assert ctx.gen_req_records() == [
|
||||
"https://example.com Headers({'host': 'example.com', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/51.0.2704.103 Safari/537.36'}) | [403] Headers({'content-length': '15', 'content-type': 'application/json'}) {\"error\": \"gg\"}"
|
||||
]
|
Loading…
x
Reference in New Issue
Block a user