diff --git a/src/plugins/nonebot_bison/platform/bilibili.py b/src/plugins/nonebot_bison/platform/bilibili.py index b18851e..2ec613e 100644 --- a/src/plugins/nonebot_bison/platform/bilibili.py +++ b/src/plugins/nonebot_bison/platform/bilibili.py @@ -1,14 +1,45 @@ +import functools import json import re +from datetime import datetime, timedelta from typing import Any, Optional +import httpx +from nonebot.log import logger + from ..post import Post from ..types import Category, RawPost, Tag, Target -from ..utils import http_client +from ..utils.http import http_args from .platform import CategoryNotSupport, NewMessage, StatusChange -class Bilibili(NewMessage): +class _BilibiliClient: + + _http_client: httpx.AsyncClient + _client_refresh_time: Optional[datetime] + cookie_expire_time = timedelta(hours=5) + + async def _init_session(self): + self._http_client = httpx.AsyncClient(**http_args) + res = await self._http_client.get("https://www.bilibili.com/") + if res.status_code != 200: + import ipdb + + ipdb.set_trace() + logger.warning("unable to refresh temp cookie") + else: + self._client_refresh_time = datetime.now() + + async def _refresh_client(self): + if ( + getattr(self, "_client_refresh_time", None) is None + or datetime.now() - self._client_refresh_time > self.cookie_expire_time + or self._http_client is None + ): + await self._init_session() + + +class Bilibili(_BilibiliClient, NewMessage): categories = { 1: "一般动态", @@ -28,15 +59,23 @@ class Bilibili(NewMessage): has_target = True parse_target_promot = "请输入用户主页的链接" + def ensure_client(fun): + @functools.wraps(fun) + async def wrapped(self, *args, **kwargs): + await self._refresh_client() + return await fun(self, *args, **kwargs) + + return wrapped + + @ensure_client async def get_target_name(self, target: Target) -> Optional[str]: - async with http_client() as client: - res = await client.get( - "https://api.bilibili.com/x/space/acc/info", params={"mid": target} - ) - res_data = json.loads(res.text) - if res_data["code"]: - return None - return res_data["data"]["name"] + res = await self._http_client.get( + "https://api.bilibili.com/x/space/acc/info", params={"mid": target} + ) + res_data = json.loads(res.text) + if res_data["code"]: + return None + return res_data["data"]["name"] async def parse_target(self, target_text: str) -> Target: if re.match(r"\d+", target_text): @@ -48,19 +87,19 @@ class Bilibili(NewMessage): else: raise self.ParseTargetException() + @ensure_client async def get_sub_list(self, target: Target) -> list[RawPost]: - async with http_client() as client: - params = {"host_uid": target, "offset": 0, "need_top": 0} - res = await client.get( - "https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history", - params=params, - timeout=4.0, - ) - res_dict = json.loads(res.text) - if res_dict["code"] == 0: - return res_dict["data"].get("cards") - else: - return [] + params = {"host_uid": target, "offset": 0, "need_top": 0} + res = await self._http_client.get( + "https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history", + params=params, + timeout=4.0, + ) + res_dict = json.loads(res.text) + if res_dict["code"] == 0: + return res_dict["data"].get("cards") + else: + return [] def get_id(self, post: RawPost) -> Any: return post["desc"]["dynamic_id"] @@ -157,7 +196,7 @@ class Bilibili(NewMessage): return Post("bilibili", text=text, url=url, pics=pic, target_name=target_name) -class Bilibililive(StatusChange): +class Bilibililive(_BilibiliClient, StatusChange): # Author : Sichongzou # Date : 2022-5-18 8:54 # Description : bilibili开播提醒 @@ -172,36 +211,44 @@ class Bilibililive(StatusChange): name = "Bilibili直播" has_target = True - async def get_target_name(self, target: Target) -> Optional[str]: - async with http_client() as client: - res = await client.get( - "https://api.bilibili.com/x/space/acc/info", params={"mid": target} - ) - res_data = json.loads(res.text) - if res_data["code"]: - return None - return res_data["data"]["name"] + def ensure_client(fun): + @functools.wraps(fun) + async def wrapped(self, *args, **kwargs): + await self._refresh_client() + return await fun(self, *args, **kwargs) + return wrapped + + @ensure_client + async def get_target_name(self, target: Target) -> Optional[str]: + res = await self._http_client.get( + "https://api.bilibili.com/x/space/acc/info", params={"mid": target} + ) + res_data = json.loads(res.text) + if res_data["code"]: + return None + return res_data["data"]["name"] + + @ensure_client async def get_status(self, target: Target): - async with http_client() as client: - params = {"mid": target} - res = await client.get( - "https://api.bilibili.com/x/space/acc/info", - params=params, - timeout=4.0, - ) - res_dict = json.loads(res.text) - if res_dict["code"] == 0: - info = {} - info["uid"] = res_dict["data"]["mid"] - info["uname"] = res_dict["data"]["name"] - info["live_state"] = res_dict["data"]["live_room"]["liveStatus"] - info["room_id"] = res_dict["data"]["live_room"]["roomid"] - info["title"] = res_dict["data"]["live_room"]["title"] - info["cover"] = res_dict["data"]["live_room"]["cover"] - return info - else: - raise self.FetchError() + params = {"mid": target} + res = await self._http_client.get( + "https://api.bilibili.com/x/space/acc/info", + params=params, + timeout=4.0, + ) + res_dict = json.loads(res.text) + if res_dict["code"] == 0: + info = {} + info["uid"] = res_dict["data"]["mid"] + info["uname"] = res_dict["data"]["name"] + info["live_state"] = res_dict["data"]["live_room"]["liveStatus"] + info["room_id"] = res_dict["data"]["live_room"]["roomid"] + info["title"] = res_dict["data"]["live_room"]["title"] + info["cover"] = res_dict["data"]["live_room"]["cover"] + return info + else: + raise self.FetchError() def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]: if ( diff --git a/src/plugins/nonebot_bison/utils/http.py b/src/plugins/nonebot_bison/utils/http.py index 082aa55..9dd80d8 100644 --- a/src/plugins/nonebot_bison/utils/http.py +++ b/src/plugins/nonebot_bison/utils/http.py @@ -4,8 +4,9 @@ import httpx from ..plugin_config import plugin_config -http_client = functools.partial( - httpx.AsyncClient, - proxies=plugin_config.bison_proxy or None, - headers={"user-agent": plugin_config.bison_ua}, -) +http_args = { + "proxies": plugin_config.bison_proxy or None, + "headers": {"user-agent": plugin_config.bison_ua}, +} + +http_client = functools.partial(httpx.AsyncClient, **http_args) diff --git a/tests/platforms/test_bilibili_live.py b/tests/platforms/test_bilibili_live.py index 53885ef..0867a80 100644 --- a/tests/platforms/test_bilibili_live.py +++ b/tests/platforms/test_bilibili_live.py @@ -26,6 +26,10 @@ async def test_fetch_bilibili_live_status(bili_live, dummy_user_subinfo): "https://api.bilibili.com/x/space/acc/info?mid=13164144" ) bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) + + bilibili_main_page_router = respx.get("https://www.bilibili.com/") + bilibili_main_page_router.mock(return_value=Response(200)) + target = "13164144" res = await bili_live.fetch_new_post(target, [dummy_user_subinfo]) assert bili_live_router.called diff --git a/tests/test_config_manager_add.py b/tests/test_config_manager_add.py index dcf63e5..d087f07 100644 --- a/tests/test_config_manager_add.py +++ b/tests/test_config_manager_add.py @@ -440,6 +440,9 @@ async def test_add_with_bilibili_target_parser(app: App): return_value=Response(200, json=get_json("bilibili_arknights_profile.json")) ) + bilibili_main_page_router = respx.get("https://www.bilibili.com/") + bilibili_main_page_router.mock(return_value=Response(200)) + async with app.test_matcher(add_sub_matcher) as ctx: bot = ctx.create_bot() event_1 = fake_group_message_event(