diff --git a/CHANGELOG.md b/CHANGELOG.md index 1182ba7..f4e7220 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,9 +6,14 @@ ### 新功能 +- feat: 临时解决 bilibili 的反爬机制 [@felinae98](https://github.com/felinae98) ([#110](https://github.com/felinae98/nonebot-bison/pull/110)) - 在StatusChange中提供了如果api返回错误不更新status的方法 [@felinae98](https://github.com/felinae98) ([#96](https://github.com/felinae98/nonebot-bison/pull/96)) - 添加 CustomPost [@felinae98](https://github.com/felinae98) ([#81](https://github.com/felinae98/nonebot-bison/pull/81)) +### Bug 修复 + +- fix: 修复 bilibili-live 中获取状态错误后产生的错误行为 [@felinae98](https://github.com/felinae98) ([#111](https://github.com/felinae98/nonebot-bison/pull/111)) + ## v0.5.4 ### 新功能 diff --git a/src/plugins/nonebot_bison/platform/bilibili.py b/src/plugins/nonebot_bison/platform/bilibili.py index 3903c0c..31378d1 100644 --- a/src/plugins/nonebot_bison/platform/bilibili.py +++ b/src/plugins/nonebot_bison/platform/bilibili.py @@ -1,10 +1,16 @@ +import functools import json import re -from typing import Any, Optional +from datetime import datetime, timedelta +from typing import Any, Callable, Optional + +import httpx +from nonebot.log import logger from ..post import Post from ..types import Category, RawPost, Tag, Target -from ..utils import SchedulerConfig, http_client +from ..utils import SchedulerConfig +from ..utils.http import http_args from .platform import CategoryNotSupport, NewMessage, StatusChange @@ -14,7 +20,37 @@ class BilibiliSchedConf(SchedulerConfig, name="bilibili.com"): schedule_setting = {"seconds": 10} -class Bilibili(NewMessage): +from .platform import CategoryNotSupport, NewMessage, StatusChange + + +class _BilibiliClient: + + _http_client: httpx.AsyncClient + _client_refresh_time: Optional[datetime] + cookie_expire_time = timedelta(hours=5) + + async def _init_session(self): + self._http_client = httpx.AsyncClient(**http_args) + res = await self._http_client.get("https://www.bilibili.com/") + if res.status_code != 200: + import ipdb + + ipdb.set_trace() + logger.warning("unable to refresh temp cookie") + else: + self._client_refresh_time = datetime.now() + + async def _refresh_client(self): + if ( + getattr(self, "_client_refresh_time", None) is None + or datetime.now() - self._client_refresh_time + > self.cookie_expire_time # type:ignore + or self._http_client is None + ): + await self._init_session() + + +class Bilibili(_BilibiliClient, NewMessage): categories = { 1: "一般动态", @@ -33,15 +69,23 @@ class Bilibili(NewMessage): has_target = True parse_target_promot = "请输入用户主页的链接" + def ensure_client(fun: Callable): # type:ignore + @functools.wraps(fun) + async def wrapped(self, *args, **kwargs): + await self._refresh_client() + return await fun(self, *args, **kwargs) + + return wrapped + + @ensure_client async def get_target_name(self, target: Target) -> Optional[str]: - async with http_client() as client: - res = await client.get( - "https://api.bilibili.com/x/space/acc/info", params={"mid": target} - ) - res_data = json.loads(res.text) - if res_data["code"]: - return None - return res_data["data"]["name"] + res = await self._http_client.get( + "https://api.bilibili.com/x/space/acc/info", params={"mid": target} + ) + res_data = json.loads(res.text) + if res_data["code"]: + return None + return res_data["data"]["name"] async def parse_target(self, target_text: str) -> Target: if re.match(r"\d+", target_text): @@ -51,19 +95,19 @@ class Bilibili(NewMessage): else: raise self.ParseTargetException() + @ensure_client async def get_sub_list(self, target: Target) -> list[RawPost]: - async with http_client() as client: - params = {"host_uid": target, "offset": 0, "need_top": 0} - res = await client.get( - "https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history", - params=params, - timeout=4.0, - ) - res_dict = json.loads(res.text) - if res_dict["code"] == 0: - return res_dict["data"].get("cards") - else: - return [] + params = {"host_uid": target, "offset": 0, "need_top": 0} + res = await self._http_client.get( + "https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history", + params=params, + timeout=4.0, + ) + res_dict = json.loads(res.text) + if res_dict["code"] == 0: + return res_dict["data"].get("cards") + else: + return [] def get_id(self, post: RawPost) -> Any: return post["desc"]["dynamic_id"] @@ -160,7 +204,7 @@ class Bilibili(NewMessage): return Post("bilibili", text=text, url=url, pics=pic, target_name=target_name) -class Bilibililive(StatusChange): +class Bilibililive(_BilibiliClient, StatusChange): # Author : Sichongzou # Date : 2022-5-18 8:54 # Description : bilibili开播提醒 @@ -174,36 +218,44 @@ class Bilibililive(StatusChange): name = "Bilibili直播" has_target = True - async def get_target_name(self, target: Target) -> Optional[str]: - async with http_client() as client: - res = await client.get( - "https://api.bilibili.com/x/space/acc/info", params={"mid": target} - ) - res_data = json.loads(res.text) - if res_data["code"]: - return None - return res_data["data"]["name"] + def ensure_client(fun: Callable): # type:ignore + @functools.wraps(fun) + async def wrapped(self, *args, **kwargs): + await self._refresh_client() + return await fun(self, *args, **kwargs) + return wrapped + + @ensure_client + async def get_target_name(self, target: Target) -> Optional[str]: + res = await self._http_client.get( + "https://api.bilibili.com/x/space/acc/info", params={"mid": target} + ) + res_data = json.loads(res.text) + if res_data["code"]: + return None + return res_data["data"]["name"] + + @ensure_client async def get_status(self, target: Target): - async with http_client() as client: - params = {"mid": target} - res = await client.get( - "https://api.bilibili.com/x/space/acc/info", - params=params, - timeout=4.0, - ) - res_dict = json.loads(res.text) - if res_dict["code"] == 0: - info = {} - info["uid"] = res_dict["data"]["mid"] - info["uname"] = res_dict["data"]["name"] - info["live_state"] = res_dict["data"]["live_room"]["liveStatus"] - info["room_id"] = res_dict["data"]["live_room"]["roomid"] - info["title"] = res_dict["data"]["live_room"]["title"] - info["cover"] = res_dict["data"]["live_room"]["cover"] - return info - else: - raise self.ParseTargetException(res.text) + params = {"mid": target} + res = await self._http_client.get( + "https://api.bilibili.com/x/space/acc/info", + params=params, + timeout=4.0, + ) + res_dict = json.loads(res.text) + if res_dict["code"] == 0: + info = {} + info["uid"] = res_dict["data"]["mid"] + info["uname"] = res_dict["data"]["name"] + info["live_state"] = res_dict["data"]["live_room"]["liveStatus"] + info["room_id"] = res_dict["data"]["live_room"]["roomid"] + info["title"] = res_dict["data"]["live_room"]["title"] + info["cover"] = res_dict["data"]["live_room"]["cover"] + return info + else: + raise self.FetchError() def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]: if ( @@ -229,7 +281,7 @@ class Bilibililive(StatusChange): ) -class BilibiliBangumi(StatusChange): +class BilibiliBangumi(_BilibiliClient, StatusChange): categories = {} platform_name = "bilibili-bangumi" @@ -243,13 +295,21 @@ class BilibiliBangumi(StatusChange): _url = "https://api.bilibili.com/pgc/review/user" + def ensure_client(fun: Callable): # type:ignore + @functools.wraps(fun) + async def wrapped(self, *args, **kwargs): + await self._refresh_client() + return await fun(self, *args, **kwargs) + + return wrapped + + @ensure_client async def get_target_name(self, target: Target) -> Optional[str]: - async with http_client() as client: - res = await client.get(self._url, params={"media_id": target}) - res_data = res.json() - if res_data["code"]: - return None - return res_data["result"]["media"]["title"] + res = await self._http_client.get(self._url, params={"media_id": target}) + res_data = res.json() + if res_data["code"]: + return None + return res_data["result"]["media"]["title"] async def parse_target(self, target_string: str) -> Target: if re.match(r"\d+", target_string): @@ -262,22 +322,22 @@ class BilibiliBangumi(StatusChange): return Target(m.group(1)) raise self.ParseTargetException() + @ensure_client async def get_status(self, target: Target): - async with http_client() as client: - res = await client.get( - self._url, - params={"media_id": target}, - timeout=4.0, - ) - res_dict = res.json() - if res_dict["code"] == 0: - return { - "index": res_dict["result"]["media"]["new_ep"]["index"], - "index_show": res_dict["result"]["media"]["new_ep"]["index"], - "season_id": res_dict["result"]["media"]["season_id"], - } - else: - raise self.FetchError + res = await self._http_client.get( + self._url, + params={"media_id": target}, + timeout=4.0, + ) + res_dict = res.json() + if res_dict["code"] == 0: + return { + "index": res_dict["result"]["media"]["new_ep"]["index"], + "index_show": res_dict["result"]["media"]["new_ep"]["index"], + "season_id": res_dict["result"]["media"]["season_id"], + } + else: + raise self.FetchError def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]: if new_status["index"] != old_status["index"]: @@ -285,11 +345,11 @@ class BilibiliBangumi(StatusChange): else: return [] + @ensure_client async def parse(self, raw_post: RawPost) -> Post: - async with http_client() as client: - detail_res = await client.get( - f'http://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}' - ) + detail_res = await self._http_client.get( + f'http://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}' + ) detail_dict = detail_res.json() lastest_episode = None for episode in detail_dict["result"]["episodes"][::-1]: diff --git a/src/plugins/nonebot_bison/utils/http.py b/src/plugins/nonebot_bison/utils/http.py index 082aa55..9dd80d8 100644 --- a/src/plugins/nonebot_bison/utils/http.py +++ b/src/plugins/nonebot_bison/utils/http.py @@ -4,8 +4,9 @@ import httpx from ..plugin_config import plugin_config -http_client = functools.partial( - httpx.AsyncClient, - proxies=plugin_config.bison_proxy or None, - headers={"user-agent": plugin_config.bison_ua}, -) +http_args = { + "proxies": plugin_config.bison_proxy or None, + "headers": {"user-agent": plugin_config.bison_ua}, +} + +http_client = functools.partial(httpx.AsyncClient, **http_args) diff --git a/tests/platforms/test_bilibili_live.py b/tests/platforms/test_bilibili_live.py index eb7cbb6..b369ac3 100644 --- a/tests/platforms/test_bilibili_live.py +++ b/tests/platforms/test_bilibili_live.py @@ -22,6 +22,10 @@ async def test_fetch_bilibili_live_status(bili_live, dummy_user_subinfo): "https://api.bilibili.com/x/space/acc/info?mid=13164144" ) bili_live_router.mock(return_value=Response(200, json=mock_bili_live_status)) + + bilibili_main_page_router = respx.get("https://www.bilibili.com/") + bilibili_main_page_router.mock(return_value=Response(200)) + target = "13164144" res = await bili_live.fetch_new_post(target, [dummy_user_subinfo]) assert bili_live_router.called diff --git a/tests/test_config_manager_add.py b/tests/test_config_manager_add.py index 4dc2b1f..383fe51 100644 --- a/tests/test_config_manager_add.py +++ b/tests/test_config_manager_add.py @@ -422,6 +422,9 @@ async def test_add_with_bilibili_target_parser(app: App, init_scheduler): return_value=Response(200, json=get_json("bilibili_arknights_profile.json")) ) + bilibili_main_page_router = respx.get("https://www.bilibili.com/") + bilibili_main_page_router.mock(return_value=Response(200)) + async with app.test_matcher(add_sub_matcher) as ctx: bot = ctx.create_bot() event_1 = fake_group_message_event(