diff --git a/src/plugins/nonebot_bison/admin_page/api.py b/src/plugins/nonebot_bison/admin_page/api.py index 4e926f4..19ec3a5 100644 --- a/src/plugins/nonebot_bison/admin_page/api.py +++ b/src/plugins/nonebot_bison/admin_page/api.py @@ -1,6 +1,7 @@ import nonebot from nonebot.adapters.onebot.v11.bot import Bot +from ..apis import check_sub_target from ..config import ( NoSuchSubscribeException, NoSuchTargetException, @@ -8,7 +9,7 @@ from ..config import ( config, ) from ..config.db_config import SubscribeDupException -from ..platform import check_sub_target, platform_manager +from ..platform import platform_manager from ..types import Target as T_Target from ..types import WeightConfig from .jwt import pack_jwt diff --git a/src/plugins/nonebot_bison/apis.py b/src/plugins/nonebot_bison/apis.py new file mode 100644 index 0000000..872a7d5 --- /dev/null +++ b/src/plugins/nonebot_bison/apis.py @@ -0,0 +1,12 @@ +from .platform import platform_manager +from .scheduler import scheduler_dict +from .types import Target + + +async def check_sub_target(platform_name: str, target: Target): + platform = platform_manager[platform_name] + scheduler_conf_class = platform.scheduler + scheduler = scheduler_dict[scheduler_conf_class] + client = await scheduler.scheduler_config_obj.get_query_name_client() + + return await platform_manager[platform_name].get_target_name(client, target) diff --git a/src/plugins/nonebot_bison/config_manager.py b/src/plugins/nonebot_bison/config_manager.py index f00fa46..a9678fd 100644 --- a/src/plugins/nonebot_bison/config_manager.py +++ b/src/plugins/nonebot_bison/config_manager.py @@ -15,9 +15,10 @@ from nonebot.permission import SUPERUSER from nonebot.rule import to_me from nonebot.typing import T_State +from .apis import check_sub_target from .config import config from .config.db_config import SubscribeDupException -from .platform import Platform, check_sub_target, platform_manager +from .platform import Platform, platform_manager from .plugin_config import plugin_config from .types import Category, Target, User from .utils import parse_text @@ -117,9 +118,7 @@ def do_add_sub(add_sub: Type[Matcher]): ) + "请输入订阅用户的id\n查询id获取方法请回复:“查询”" else: state["id"] = "default" - state["name"] = await platform_manager[state["platform"]].get_target_name( - Target("") - ) + state["name"] = await check_sub_target(state["platform"], Target("")) async def parse_id(event: MessageEvent, state: T_State): if not isinstance(state["id"], Message): diff --git a/src/plugins/nonebot_bison/platform/__init__.py b/src/plugins/nonebot_bison/platform/__init__.py index 60b5e32..e8d7186 100644 --- a/src/plugins/nonebot_bison/platform/__init__.py +++ b/src/plugins/nonebot_bison/platform/__init__.py @@ -2,29 +2,24 @@ from collections import defaultdict from importlib import import_module from pathlib import Path from pkgutil import iter_modules +from typing import DefaultDict, Type -from .platform import NoTargetGroup, Platform +from .platform import Platform, make_no_target_group _package_dir = str(Path(__file__).resolve().parent) for (_, module_name, _) in iter_modules([_package_dir]): import_module(f"{__name__}.{module_name}") -async def check_sub_target(target_type, target): - return await platform_manager[target_type].get_target_name(target) - - -_platform_list = defaultdict(list) +_platform_list: DefaultDict[str, list[Type[Platform]]] = defaultdict(list) for _platform in Platform.registry: if not _platform.enabled: continue _platform_list[_platform.platform_name].append(_platform) -platform_manager: dict[str, Platform] = dict() +platform_manager: dict[str, Type[Platform]] = dict() for name, platform_list in _platform_list.items(): if len(platform_list) == 1: - platform_manager[name] = platform_list[0]() + platform_manager[name] = platform_list[0] else: - platform_manager[name] = NoTargetGroup( - [_platform() for _platform in platform_list] - ) + platform_manager[name] = make_no_target_group(platform_list) diff --git a/src/plugins/nonebot_bison/platform/arknights.py b/src/plugins/nonebot_bison/platform/arknights.py index 90c5785..3b7d9a3 100644 --- a/src/plugins/nonebot_bison/platform/arknights.py +++ b/src/plugins/nonebot_bison/platform/arknights.py @@ -1,12 +1,12 @@ import json -from typing import Any +from typing import Any, Optional from bs4 import BeautifulSoup as bs +from httpx import AsyncClient from nonebot.plugin import require from ..post import Post from ..types import Category, RawPost, Target -from ..utils import http_client from ..utils.scheduler_config import SchedulerConfig from .platform import CategoryNotSupport, NewMessage, StatusChange @@ -29,15 +29,17 @@ class Arknights(NewMessage): scheduler = ArknightsSchedConf has_target = False - async def get_target_name(self, _: Target) -> str: + @classmethod + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: return "明日方舟游戏信息" async def get_sub_list(self, _) -> list[RawPost]: - async with http_client() as client: - raw_data = await client.get( - "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/announcement.meta.json" - ) - return json.loads(raw_data.text)["announceList"] + raw_data = await self.client.get( + "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/announcement.meta.json" + ) + return json.loads(raw_data.text)["announceList"] def get_id(self, post: RawPost) -> Any: return post["announceId"] @@ -51,8 +53,7 @@ class Arknights(NewMessage): async def parse(self, raw_post: RawPost) -> Post: announce_url = raw_post["webUrl"] text = "" - async with http_client() as client: - raw_html = await client.get(announce_url) + raw_html = await self.client.get(announce_url) soup = bs(raw_html.text, "html.parser") pics = [] if soup.find("div", class_="standerd-container"): @@ -101,17 +102,19 @@ class AkVersion(StatusChange): scheduler = ArknightsSchedConf has_target = False - async def get_target_name(self, _: Target) -> str: + @classmethod + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: return "明日方舟游戏信息" async def get_status(self, _): - async with http_client() as client: - res_ver = await client.get( - "https://ak-conf.hypergryph.com/config/prod/official/IOS/version" - ) - res_preanounce = await client.get( - "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json" - ) + res_ver = await self.client.get( + "https://ak-conf.hypergryph.com/config/prod/official/IOS/version" + ) + res_preanounce = await self.client.get( + "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json" + ) res = res_ver.json() res.update(res_preanounce.json()) return res @@ -156,13 +159,17 @@ class MonsterSiren(NewMessage): scheduler = ArknightsSchedConf has_target = False - async def get_target_name(self, _: Target) -> str: + @classmethod + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: return "明日方舟游戏信息" async def get_sub_list(self, _) -> list[RawPost]: - async with http_client() as client: - raw_data = await client.get("https://monster-siren.hypergryph.com/api/news") - return raw_data.json()["data"]["list"] + raw_data = await self.client.get( + "https://monster-siren.hypergryph.com/api/news" + ) + return raw_data.json()["data"]["list"] def get_id(self, post: RawPost) -> Any: return post["cid"] @@ -175,16 +182,15 @@ class MonsterSiren(NewMessage): async def parse(self, raw_post: RawPost) -> Post: url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}' - async with http_client() as client: - res = await client.get( - f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}' - ) - raw_data = res.json() - content = raw_data["data"]["content"] - content = content.replace("

", "

\n") - soup = bs(content, "html.parser") - imgs = list(map(lambda x: x["src"], soup("img"))) - text = f'{raw_post["title"]}\n{soup.text.strip()}' + res = await self.client.get( + f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}' + ) + raw_data = res.json() + content = raw_data["data"]["content"] + content = content.replace("

", "

\n") + soup = bs(content, "html.parser") + imgs = list(map(lambda x: x["src"], soup("img"))) + text = f'{raw_post["title"]}\n{soup.text.strip()}' return Post( "monster-siren", text=text, @@ -207,15 +213,17 @@ class TerraHistoricusComic(NewMessage): scheduler = ArknightsSchedConf has_target = False - async def get_target_name(self, _: Target) -> str: + @classmethod + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: return "明日方舟游戏信息" async def get_sub_list(self, _) -> list[RawPost]: - async with http_client() as client: - raw_data = await client.get( - "https://terra-historicus.hypergryph.com/api/recentUpdate" - ) - return raw_data.json()["data"] + raw_data = await self.client.get( + "https://terra-historicus.hypergryph.com/api/recentUpdate" + ) + return raw_data.json()["data"] def get_id(self, post: RawPost) -> Any: return f'{post["comicCid"]}/{post["episodeCid"]}' diff --git a/src/plugins/nonebot_bison/platform/bilibili.py b/src/plugins/nonebot_bison/platform/bilibili.py index 7799e08..fe27414 100644 --- a/src/plugins/nonebot_bison/platform/bilibili.py +++ b/src/plugins/nonebot_bison/platform/bilibili.py @@ -5,6 +5,7 @@ from datetime import datetime, timedelta from typing import Any, Callable, Optional import httpx +from httpx import AsyncClient from nonebot.log import logger from ..post import Post @@ -20,35 +21,36 @@ class BilibiliSchedConf(SchedulerConfig): schedule_type = "interval" schedule_setting = {"seconds": 10} - -from .platform import CategoryNotSupport, NewMessage, StatusChange - - -class _BilibiliClient: - - _http_client: httpx.AsyncClient - _client_refresh_time: Optional[datetime] + _client_refresh_time: datetime cookie_expire_time = timedelta(hours=5) + def __init__(self): + self._client_refresh_time = datetime( + year=2000, month=1, day=1 + ) # an expired time + super().__init__() + async def _init_session(self): - self._http_client = httpx.AsyncClient(**http_args) - res = await self._http_client.get("https://www.bilibili.com/") + res = await self.default_http_client.get("https://www.bilibili.com/") if res.status_code != 200: logger.warning("unable to refresh temp cookie") else: self._client_refresh_time = datetime.now() async def _refresh_client(self): - if ( - getattr(self, "_client_refresh_time", None) is None - or datetime.now() - self._client_refresh_time - > self.cookie_expire_time # type:ignore - or self._http_client is None - ): + if datetime.now() - self._client_refresh_time > self.cookie_expire_time: await self._init_session() + async def get_client(self, target: Target) -> AsyncClient: + await self._refresh_client() + return await super().get_client(target) -class Bilibili(_BilibiliClient, NewMessage): + async def get_query_name_client(self) -> AsyncClient: + await self._refresh_client() + return await super().get_query_name_client() + + +class Bilibili(NewMessage): categories = { 1: "一般动态", @@ -67,17 +69,11 @@ class Bilibili(_BilibiliClient, NewMessage): has_target = True parse_target_promot = "请输入用户主页的链接" - def ensure_client(fun: Callable): # type:ignore - @functools.wraps(fun) - async def wrapped(self, *args, **kwargs): - await self._refresh_client() - return await fun(self, *args, **kwargs) - - return wrapped - - @ensure_client - async def get_target_name(self, target: Target) -> Optional[str]: - res = await self._http_client.get( + @classmethod + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: + res = await client.get( "https://api.bilibili.com/x/space/acc/info", params={"mid": target} ) res_data = json.loads(res.text) @@ -85,18 +81,18 @@ class Bilibili(_BilibiliClient, NewMessage): return None return res_data["data"]["name"] - async def parse_target(self, target_text: str) -> Target: + @classmethod + async def parse_target(cls, target_text: str) -> Target: if re.match(r"\d+", target_text): return Target(target_text) elif m := re.match(r"(?:https?://)?space\.bilibili\.com/(\d+)", target_text): return Target(m.group(1)) else: - raise self.ParseTargetException() + raise cls.ParseTargetException() - @ensure_client async def get_sub_list(self, target: Target) -> list[RawPost]: params = {"host_uid": target, "offset": 0, "need_top": 0} - res = await self._http_client.get( + res = await self.client.get( "https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history", params=params, timeout=4.0, @@ -202,7 +198,7 @@ class Bilibili(_BilibiliClient, NewMessage): return Post("bilibili", text=text, url=url, pics=pic, target_name=target_name) -class Bilibililive(_BilibiliClient, StatusChange): +class Bilibililive(StatusChange): # Author : Sichongzou # Date : 2022-5-18 8:54 # Description : bilibili开播提醒 @@ -216,17 +212,11 @@ class Bilibililive(_BilibiliClient, StatusChange): name = "Bilibili直播" has_target = True - def ensure_client(fun: Callable): # type:ignore - @functools.wraps(fun) - async def wrapped(self, *args, **kwargs): - await self._refresh_client() - return await fun(self, *args, **kwargs) - - return wrapped - - @ensure_client - async def get_target_name(self, target: Target) -> Optional[str]: - res = await self._http_client.get( + @classmethod + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: + res = await client.get( "https://api.bilibili.com/x/space/acc/info", params={"mid": target} ) res_data = json.loads(res.text) @@ -234,10 +224,9 @@ class Bilibililive(_BilibiliClient, StatusChange): return None return res_data["data"]["name"] - @ensure_client async def get_status(self, target: Target): params = {"mid": target} - res = await self._http_client.get( + res = await self.client.get( "https://api.bilibili.com/x/space/acc/info", params=params, timeout=4.0, @@ -279,7 +268,7 @@ class Bilibililive(_BilibiliClient, StatusChange): ) -class BilibiliBangumi(_BilibiliClient, StatusChange): +class BilibiliBangumi(StatusChange): categories = {} platform_name = "bilibili-bangumi" @@ -293,23 +282,18 @@ class BilibiliBangumi(_BilibiliClient, StatusChange): _url = "https://api.bilibili.com/pgc/review/user" - def ensure_client(fun: Callable): # type:ignore - @functools.wraps(fun) - async def wrapped(self, *args, **kwargs): - await self._refresh_client() - return await fun(self, *args, **kwargs) - - return wrapped - - @ensure_client - async def get_target_name(self, target: Target) -> Optional[str]: - res = await self._http_client.get(self._url, params={"media_id": target}) + @classmethod + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: + res = await client.get(cls._url, params={"media_id": target}) res_data = res.json() if res_data["code"]: return None return res_data["result"]["media"]["title"] - async def parse_target(self, target_string: str) -> Target: + @classmethod + async def parse_target(cls, target_string: str) -> Target: if re.match(r"\d+", target_string): return Target(target_string) elif m := re.match(r"md(\d+)", target_string): @@ -318,11 +302,10 @@ class BilibiliBangumi(_BilibiliClient, StatusChange): r"(?:https?://)?www\.bilibili\.com/bangumi/media/md(\d+)/", target_string ): return Target(m.group(1)) - raise self.ParseTargetException() + raise cls.ParseTargetException() - @ensure_client async def get_status(self, target: Target): - res = await self._http_client.get( + res = await self.client.get( self._url, params={"media_id": target}, timeout=4.0, @@ -343,9 +326,8 @@ class BilibiliBangumi(_BilibiliClient, StatusChange): else: return [] - @ensure_client async def parse(self, raw_post: RawPost) -> Post: - detail_res = await self._http_client.get( + detail_res = await self.client.get( f'http://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}' ) detail_dict = detail_res.json() diff --git a/src/plugins/nonebot_bison/platform/ff14.py b/src/plugins/nonebot_bison/platform/ff14.py index bba784c..a8dfe40 100644 --- a/src/plugins/nonebot_bison/platform/ff14.py +++ b/src/plugins/nonebot_bison/platform/ff14.py @@ -1,4 +1,6 @@ -from typing import Any +from typing import Any, Optional + +from httpx import AsyncClient from ..post import Post from ..types import RawPost, Target @@ -18,7 +20,10 @@ class FF14(NewMessage): scheduler = scheduler("interval", {"seconds": 60}) has_target = False - async def get_target_name(self, _: Target) -> str: + @classmethod + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: return "最终幻想XIV官方公告" async def get_sub_list(self, _) -> list[RawPost]: diff --git a/src/plugins/nonebot_bison/platform/mcbbsnews.py b/src/plugins/nonebot_bison/platform/mcbbsnews.py index b61f146..cb762a0 100644 --- a/src/plugins/nonebot_bison/platform/mcbbsnews.py +++ b/src/plugins/nonebot_bison/platform/mcbbsnews.py @@ -1,9 +1,10 @@ import re import time -from typing import Literal +from typing import Literal, Optional import httpx from bs4 import BeautifulSoup, NavigableString, Tag +from httpx import AsyncClient from ..post import Post from ..types import Category, RawPost, Target @@ -42,8 +43,11 @@ class McbbsNews(NewMessage): scheduler = scheduler("interval", {"hours": 1}) has_target = False - async def get_target_name(self, _: Target) -> str: - return self.name + @classmethod + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: + return cls.name async def get_sub_list(self, _: Target) -> list[RawPost]: url = "https://www.mcbbs.net/forum-news-1.html" diff --git a/src/plugins/nonebot_bison/platform/ncm.py b/src/plugins/nonebot_bison/platform/ncm.py new file mode 100644 index 0000000..68e4189 --- /dev/null +++ b/src/plugins/nonebot_bison/platform/ncm.py @@ -0,0 +1,146 @@ +import re +from typing import Any, Optional + +from httpx import AsyncClient + +from ..post import Post +from ..types import RawPost, Target +from ..utils import SchedulerConfig, http_client +from .platform import NewMessage + + +class NcmSchedConf(SchedulerConfig): + + name = "music.163.com" + schedule_type = "interval" + schedule_setting = {"minutes": 1} + + +class NcmArtist(NewMessage): + + categories = {} + platform_name = "ncm-artist" + enable_tag = False + enabled = True + is_common = True + scheduler = NcmSchedConf + name = "网易云-歌手" + has_target = True + parse_target_promot = "请输入歌手主页(包含数字ID)的链接" + + @classmethod + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: + async with http_client() as client: + res = await client.get( + "https://music.163.com/api/artist/albums/{}".format(target), + headers={"Referer": "https://music.163.com/"}, + ) + res_data = res.json() + if res_data["code"] != 200: + return + return res_data["artist"]["name"] + + @classmethod + async def parse_target(cls, target_text: str) -> Target: + if re.match(r"^\d+$", target_text): + return Target(target_text) + elif match := re.match( + r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text + ): + return Target(match.group(1)) + else: + raise cls.ParseTargetException() + + async def get_sub_list(self, target: Target) -> list[RawPost]: + async with http_client() as client: + res = await client.get( + "https://music.163.com/api/artist/albums/{}".format(target), + headers={"Referer": "https://music.163.com/"}, + ) + res_data = res.json() + if res_data["code"] != 200: + return [] + else: + return res_data["hotAlbums"] + + def get_id(self, post: RawPost) -> Any: + return post["id"] + + def get_date(self, post: RawPost) -> int: + return post["publishTime"] // 1000 + + async def parse(self, raw_post: RawPost) -> Post: + text = "新专辑发布:{}".format(raw_post["name"]) + target_name = raw_post["artist"]["name"] + pics = [raw_post["picUrl"]] + url = "https://music.163.com/#/album?id={}".format(raw_post["id"]) + return Post( + "ncm-artist", text=text, url=url, pics=pics, target_name=target_name + ) + + +class NcmRadio(NewMessage): + + categories = {} + platform_name = "ncm-radio" + enable_tag = False + enabled = True + is_common = False + scheduler = NcmSchedConf + name = "网易云-电台" + has_target = True + parse_target_promot = "请输入主播电台主页(包含数字ID)的链接" + + @classmethod + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: + async with http_client() as client: + res = await client.post( + "http://music.163.com/api/dj/program/byradio", + headers={"Referer": "https://music.163.com/"}, + data={"radioId": target, "limit": 1000, "offset": 0}, + ) + res_data = res.json() + if res_data["code"] != 200 or res_data["programs"] == 0: + return + return res_data["programs"][0]["radio"]["name"] + + @classmethod + async def parse_target(cls, target_text: str) -> Target: + if re.match(r"^\d+$", target_text): + return Target(target_text) + elif match := re.match( + r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text + ): + return Target(match.group(1)) + else: + raise cls.ParseTargetException() + + async def get_sub_list(self, target: Target) -> list[RawPost]: + async with http_client() as client: + res = await client.post( + "http://music.163.com/api/dj/program/byradio", + headers={"Referer": "https://music.163.com/"}, + data={"radioId": target, "limit": 1000, "offset": 0}, + ) + res_data = res.json() + if res_data["code"] != 200: + return [] + else: + return res_data["programs"] + + def get_id(self, post: RawPost) -> Any: + return post["id"] + + def get_date(self, post: RawPost) -> int: + return post["createTime"] // 1000 + + async def parse(self, raw_post: RawPost) -> Post: + text = "网易云电台更新:{}".format(raw_post["name"]) + target_name = raw_post["radio"]["name"] + pics = [raw_post["coverUrl"]] + url = "https://music.163.com/#/program/{}".format(raw_post["id"]) + return Post("ncm-radio", text=text, url=url, pics=pics, target_name=target_name) diff --git a/src/plugins/nonebot_bison/platform/ncm_artist.py b/src/plugins/nonebot_bison/platform/ncm_artist.py deleted file mode 100644 index a15349e..0000000 --- a/src/plugins/nonebot_bison/platform/ncm_artist.py +++ /dev/null @@ -1,75 +0,0 @@ -import re -from typing import Any, Optional - -from ..post import Post -from ..types import RawPost, Target -from ..utils import SchedulerConfig, http_client -from .platform import NewMessage - - -class NcmSchedConf(SchedulerConfig): - - name = "music.163.com" - schedule_type = "interval" - schedule_setting = {"minutes": 1} - - -class NcmArtist(NewMessage): - - categories = {} - platform_name = "ncm-artist" - enable_tag = False - enabled = True - is_common = True - scheduler = NcmSchedConf - name = "网易云-歌手" - has_target = True - parse_target_promot = "请输入歌手主页(包含数字ID)的链接" - - async def get_target_name(self, target: Target) -> Optional[str]: - async with http_client() as client: - res = await client.get( - "https://music.163.com/api/artist/albums/{}".format(target), - headers={"Referer": "https://music.163.com/"}, - ) - res_data = res.json() - if res_data["code"] != 200: - return - return res_data["artist"]["name"] - - async def parse_target(self, target_text: str) -> Target: - if re.match(r"^\d+$", target_text): - return Target(target_text) - elif match := re.match( - r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text - ): - return Target(match.group(1)) - else: - raise self.ParseTargetException() - - async def get_sub_list(self, target: Target) -> list[RawPost]: - async with http_client() as client: - res = await client.get( - "https://music.163.com/api/artist/albums/{}".format(target), - headers={"Referer": "https://music.163.com/"}, - ) - res_data = res.json() - if res_data["code"] != 200: - return [] - else: - return res_data["hotAlbums"] - - def get_id(self, post: RawPost) -> Any: - return post["id"] - - def get_date(self, post: RawPost) -> int: - return post["publishTime"] // 1000 - - async def parse(self, raw_post: RawPost) -> Post: - text = "新专辑发布:{}".format(raw_post["name"]) - target_name = raw_post["artist"]["name"] - pics = [raw_post["picUrl"]] - url = "https://music.163.com/#/album?id={}".format(raw_post["id"]) - return Post( - "ncm-artist", text=text, url=url, pics=pics, target_name=target_name - ) diff --git a/src/plugins/nonebot_bison/platform/ncm_radio.py b/src/plugins/nonebot_bison/platform/ncm_radio.py deleted file mode 100644 index 4170eb2..0000000 --- a/src/plugins/nonebot_bison/platform/ncm_radio.py +++ /dev/null @@ -1,69 +0,0 @@ -import re -from typing import Any, Optional - -from ..post import Post -from ..types import RawPost, Target -from ..utils import http_client -from .ncm_artist import NcmSchedConf -from .platform import NewMessage - - -class NcmRadio(NewMessage): - - categories = {} - platform_name = "ncm-radio" - enable_tag = False - enabled = True - is_common = False - scheduler = NcmSchedConf - name = "网易云-电台" - has_target = True - parse_target_promot = "请输入主播电台主页(包含数字ID)的链接" - - async def get_target_name(self, target: Target) -> Optional[str]: - async with http_client() as client: - res = await client.post( - "http://music.163.com/api/dj/program/byradio", - headers={"Referer": "https://music.163.com/"}, - data={"radioId": target, "limit": 1000, "offset": 0}, - ) - res_data = res.json() - if res_data["code"] != 200 or res_data["programs"] == 0: - return - return res_data["programs"][0]["radio"]["name"] - - async def parse_target(self, target_text: str) -> Target: - if re.match(r"^\d+$", target_text): - return Target(target_text) - elif match := re.match( - r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text - ): - return Target(match.group(1)) - else: - raise self.ParseTargetException() - - async def get_sub_list(self, target: Target) -> list[RawPost]: - async with http_client() as client: - res = await client.post( - "http://music.163.com/api/dj/program/byradio", - headers={"Referer": "https://music.163.com/"}, - data={"radioId": target, "limit": 1000, "offset": 0}, - ) - res_data = res.json() - if res_data["code"] != 200: - return [] - else: - return res_data["programs"] - - def get_id(self, post: RawPost) -> Any: - return post["id"] - - def get_date(self, post: RawPost) -> int: - return post["createTime"] // 1000 - - async def parse(self, raw_post: RawPost) -> Post: - text = "网易云电台更新:{}".format(raw_post["name"]) - target_name = raw_post["radio"]["name"] - pics = [raw_post["coverUrl"]] - url = "https://music.163.com/#/program/{}".format(raw_post["id"]) - return Post("ncm-radio", text=text, url=url, pics=pics, target_name=target_name) diff --git a/src/plugins/nonebot_bison/platform/platform.py b/src/plugins/nonebot_bison/platform/platform.py index 915ca50..8d7d3be 100644 --- a/src/plugins/nonebot_bison/platform/platform.py +++ b/src/plugins/nonebot_bison/platform/platform.py @@ -1,12 +1,14 @@ import json import ssl import time +import typing from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass -from typing import Any, Collection, Literal, Optional, Type +from typing import Any, Collection, Optional, Type import httpx +from httpx import AsyncClient from nonebot.log import logger from ..plugin_config import plugin_config @@ -34,11 +36,23 @@ class RegistryMeta(type): super().__init__(name, bases, namespace, **kwargs) -class RegistryABCMeta(RegistryMeta, ABC): +class PlatformMeta(RegistryMeta): + + categories: dict[Category, str] + + def __init__(cls, name, bases, namespace, **kwargs): + cls.reverse_category = {} + if hasattr(cls, "categories") and cls.categories: + for key, val in cls.categories.items(): + cls.reverse_category[val] = key + super().__init__(name, bases, namespace, **kwargs) + + +class PlatformABCMeta(PlatformMeta, ABC): ... -class Platform(metaclass=RegistryABCMeta, base=True): +class Platform(metaclass=PlatformABCMeta, base=True): scheduler: Type[SchedulerConfig] is_common: bool @@ -50,9 +64,15 @@ class Platform(metaclass=RegistryABCMeta, base=True): store: dict[Target, Any] platform_name: str parse_target_promot: Optional[str] = None + registry: list[Type["Platform"]] + client: AsyncClient + reverse_category: dict[str, Category] + @classmethod @abstractmethod - async def get_target_name(self, target: Target) -> Optional[str]: + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: ... @abstractmethod @@ -88,17 +108,16 @@ class Platform(metaclass=RegistryABCMeta, base=True): "actually function called" return await self.parse(raw_post) - def __init__(self): + def __init__(self, client: AsyncClient): super().__init__() - self.reverse_category = {} - for key, val in self.categories.items(): - self.reverse_category[val] = key self.store = dict() + self.client = client class ParseTargetException(Exception): pass - async def parse_target(self, target_string: str) -> Target: + @classmethod + async def parse_target(cls, target_string: str) -> Target: return Target(target_string) @abstractmethod @@ -188,8 +207,8 @@ class Platform(metaclass=RegistryABCMeta, base=True): class MessageProcess(Platform, abstract=True): "General message process fetch, parse, filter progress" - def __init__(self): - super().__init__() + def __init__(self, client: AsyncClient): + super().__init__(client) self.parse_cache: dict[Any, Post] = dict() @abstractmethod @@ -362,55 +381,82 @@ class SimplePost(MessageProcess, abstract=True): return res -class NoTargetGroup(Platform, abstract=True): - enable_tag = False +def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]: + + if typing.TYPE_CHECKING: + + class NoTargetGroup(Platform, abstract=True): + platform_list: list[Type[Platform]] + platform_obj_list: list[Platform] + DUMMY_STR = "_DUMMY" - enabled = True - has_target = False - def __init__(self, platform_list: list[Platform]): - self.platform_list = platform_list - self.platform_name = platform_list[0].platform_name - name = self.DUMMY_STR - self.categories = {} - categories_keys = set() - self.scheduler = platform_list[0].scheduler - for platform in platform_list: - if platform.has_target: - raise RuntimeError( - "Platform {} should have no target".format(platform.name) - ) - if name == self.DUMMY_STR: - name = platform.name - elif name != platform.name: - raise RuntimeError( - "Platform name for {} not fit".format(self.platform_name) - ) - platform_category_key_set = set(platform.categories.keys()) - if platform_category_key_set & categories_keys: - raise RuntimeError( - "Platform categories for {} duplicate".format(self.platform_name) - ) - categories_keys |= platform_category_key_set - self.categories.update(platform.categories) - if platform.scheduler != self.scheduler: - raise RuntimeError( - "Platform scheduler for {} not fit".format(self.platform_name) - ) - self.name = name - self.is_common = platform_list[0].is_common - super().__init__() + platform_name = platform_list[0].platform_name + name = DUMMY_STR + categories_keys = set() + categories = {} + scheduler = platform_list[0].scheduler - def __str__(self): + for platform in platform_list: + if platform.has_target: + raise RuntimeError( + "Platform {} should have no target".format(platform.name) + ) + if name == DUMMY_STR: + name = platform.name + elif name != platform.name: + raise RuntimeError("Platform name for {} not fit".format(platform_name)) + platform_category_key_set = set(platform.categories.keys()) + if platform_category_key_set & categories_keys: + raise RuntimeError( + "Platform categories for {} duplicate".format(platform_name) + ) + categories_keys |= platform_category_key_set + categories.update(platform.categories) + if platform.scheduler != scheduler: + raise RuntimeError( + "Platform scheduler for {} not fit".format(platform_name) + ) + + def __init__(self: "NoTargetGroup", client: AsyncClient): + Platform.__init__(self, client) + self.platform_obj_list = [] + for platform_class in self.platform_list: + self.platform_obj_list.append(platform_class(client)) + + def __str__(self: "NoTargetGroup") -> str: return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]" - async def get_target_name(self, _): - return await self.platform_list[0].get_target_name(_) + @classmethod + async def get_target_name(cls, client: AsyncClient, target: Target): + return await platform_list[0].get_target_name(client, target) - async def fetch_new_post(self, target, users): + async def fetch_new_post( + self: "NoTargetGroup", target: Target, users: list[UserSubInfo] + ): res = defaultdict(list) - for platform in self.platform_list: + for platform in self.platform_obj_list: platform_res = await platform.fetch_new_post(target=target, users=users) for user, posts in platform_res: res[user].extend(posts) return [[key, val] for key, val in res.items()] + + return type( + "NoTargetGroup", + (Platform,), + { + "platform_list": platform_list, + "platform_name": platform_list[0].platform_name, + "name": name, + "categories": categories, + "scheduler": scheduler, + "is_common": platform_list[0].is_common, + "enabled": True, + "has_target": False, + "enable_tag": False, + "__init__": __init__, + "get_target_name": get_target_name, + "fetch_new_post": fetch_new_post, + }, + abstract=True, + ) diff --git a/src/plugins/nonebot_bison/platform/rss.py b/src/plugins/nonebot_bison/platform/rss.py index 54a5642..1ab76f3 100644 --- a/src/plugins/nonebot_bison/platform/rss.py +++ b/src/plugins/nonebot_bison/platform/rss.py @@ -3,6 +3,7 @@ from typing import Any, Optional import feedparser from bs4 import BeautifulSoup as bs +from httpx import AsyncClient from ..post import Post from ..types import RawPost, Target @@ -21,7 +22,10 @@ class Rss(NewMessage): scheduler = scheduler("interval", {"seconds": 30}) has_target = True - async def get_target_name(self, target: Target) -> Optional[str]: + @classmethod + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: async with http_client() as client: res = await client.get(target, timeout=10.0) feed = feedparser.parse(res.text) diff --git a/src/plugins/nonebot_bison/platform/weibo.py b/src/plugins/nonebot_bison/platform/weibo.py index cad1361..87bf733 100644 --- a/src/plugins/nonebot_bison/platform/weibo.py +++ b/src/plugins/nonebot_bison/platform/weibo.py @@ -5,6 +5,7 @@ from datetime import datetime from typing import Any, Optional from bs4 import BeautifulSoup as bs +from httpx import AsyncClient from nonebot.log import logger from ..post import Post @@ -36,7 +37,10 @@ class Weibo(NewMessage): has_target = True parse_target_promot = "请输入用户主页(包含数字UID)的链接" - async def get_target_name(self, target: Target) -> Optional[str]: + @classmethod + async def get_target_name( + cls, client: AsyncClient, target: Target + ) -> Optional[str]: async with http_client() as client: param = {"containerid": "100505" + target} res = await client.get( @@ -48,14 +52,15 @@ class Weibo(NewMessage): else: return None - async def parse_target(self, target_text: str) -> Target: + @classmethod + async def parse_target(cls, target_text: str) -> Target: if re.match(r"\d+", target_text): return Target(target_text) elif match := re.match(r"(?:https?://)?weibo\.com/u/(\d+)", target_text): # 都2202年了应该不会有http了吧,不过还是防一手 return Target(match.group(1)) else: - raise self.ParseTargetException() + raise cls.ParseTargetException() async def get_sub_list(self, target: Target) -> list[RawPost]: async with http_client() as client: diff --git a/src/plugins/nonebot_bison/scheduler/scheduler.py b/src/plugins/nonebot_bison/scheduler/scheduler.py index 6b61242..422f453 100644 --- a/src/plugins/nonebot_bison/scheduler/scheduler.py +++ b/src/plugins/nonebot_bison/scheduler/scheduler.py @@ -7,7 +7,6 @@ from nonebot.log import logger from ..config import config from ..platform import platform_manager -from ..platform.platform import Platform from ..send import send_msgs from ..types import Target from ..utils import SchedulerConfig @@ -36,6 +35,7 @@ class Scheduler: logger.error(f"scheduler config [{self.name}] not found, exiting") raise RuntimeError(f"{self.name} not found") self.scheduler_config = scheduler_config + self.scheduler_config_obj = self.scheduler_config() self.schedulable_list = [] for platform_name, target in schedulables: self.schedulable_list.append( @@ -86,7 +86,10 @@ class Scheduler: send_userinfo_list = await config.get_platform_target_subscribers( schedulable.platform_name, schedulable.target ) - to_send = await platform_manager[schedulable.platform_name].do_fetch_new_post( + platform_obj = platform_manager[schedulable.platform_name]( + await self.scheduler_config_obj.get_client(schedulable.target) + ) + to_send = await platform_obj.do_fetch_new_post( schedulable.target, send_userinfo_list ) if not to_send: diff --git a/src/plugins/nonebot_bison/utils/scheduler_config.py b/src/plugins/nonebot_bison/utils/scheduler_config.py index b9a5911..57360fa 100644 --- a/src/plugins/nonebot_bison/utils/scheduler_config.py +++ b/src/plugins/nonebot_bison/utils/scheduler_config.py @@ -1,5 +1,10 @@ from typing import Literal, Type +from httpx import AsyncClient + +from ..types import Target +from .http import http_client + class SchedulerConfig: @@ -10,6 +15,15 @@ class SchedulerConfig: def __str__(self): return f"[{self.name}]-{self.name}-{self.schedule_setting}" + def __init__(self): + self.default_http_client = http_client() + + async def get_client(self, target: Target) -> AsyncClient: + return self.default_http_client + + async def get_query_name_client(self) -> AsyncClient: + return self.default_http_client + def scheduler( schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict diff --git a/tests/platforms/test_arknights.py b/tests/platforms/test_arknights.py index ba81db0..d4eda02 100644 --- a/tests/platforms/test_arknights.py +++ b/tests/platforms/test_arknights.py @@ -1,6 +1,6 @@ import pytest import respx -from httpx import Response +from httpx import AsyncClient, Response from nonebug.app import App from .utils import get_file, get_json @@ -10,7 +10,7 @@ from .utils import get_file, get_json def arknights(app: App): from nonebot_bison.platform import platform_manager - return platform_manager["arknights"] + return platform_manager["arknights"](AsyncClient()) @pytest.fixture(scope="module") diff --git a/tests/platforms/test_bilibili.py b/tests/platforms/test_bilibili.py index 23ca085..c074dc8 100644 --- a/tests/platforms/test_bilibili.py +++ b/tests/platforms/test_bilibili.py @@ -3,7 +3,7 @@ from datetime import datetime import pytest import respx -from httpx import Response +from httpx import AsyncClient, Response from nonebug.app import App from pytz import timezone @@ -23,7 +23,7 @@ if typing.TYPE_CHECKING: def bilibili(app: App): from nonebot_bison.platform import platform_manager - return platform_manager["bilibili"] + return platform_manager["bilibili"](AsyncClient()) @pytest.mark.asyncio diff --git a/tests/platforms/test_bilibili_bangumi.py b/tests/platforms/test_bilibili_bangumi.py index 6b783f8..3e2db55 100644 --- a/tests/platforms/test_bilibili_bangumi.py +++ b/tests/platforms/test_bilibili_bangumi.py @@ -2,7 +2,7 @@ import typing import pytest import respx -from httpx import Response +from httpx import AsyncClient, Response from nonebug.app import App from .utils import get_json @@ -15,7 +15,7 @@ if typing.TYPE_CHECKING: def bili_bangumi(app: App): from nonebot_bison.platform import platform_manager - return platform_manager["bilibili-bangumi"] + return platform_manager["bilibili-bangumi"](AsyncClient()) @pytest.mark.asyncio diff --git a/tests/platforms/test_bilibili_live.py b/tests/platforms/test_bilibili_live.py index b369ac3..e86875e 100644 --- a/tests/platforms/test_bilibili_live.py +++ b/tests/platforms/test_bilibili_live.py @@ -1,6 +1,6 @@ import pytest import respx -from httpx import Response +from httpx import AsyncClient, Response from nonebug.app import App from .utils import get_json @@ -10,7 +10,7 @@ from .utils import get_json def bili_live(app: App): from nonebot_bison.platform import platform_manager - return platform_manager["bilibili-live"] + return platform_manager["bilibili-live"](AsyncClient()) @pytest.mark.asyncio diff --git a/tests/platforms/test_ff14.py b/tests/platforms/test_ff14.py index 4a7d665..ed6ac2e 100644 --- a/tests/platforms/test_ff14.py +++ b/tests/platforms/test_ff14.py @@ -1,6 +1,6 @@ import pytest import respx -from httpx import Response +from httpx import AsyncClient, Response from nonebug.app import App from .utils import get_json @@ -10,7 +10,7 @@ from .utils import get_json def ff14(app: App): from nonebot_bison.platform import platform_manager - return platform_manager["ff14"] + return platform_manager["ff14"](AsyncClient()) @pytest.fixture(scope="module") diff --git a/tests/platforms/test_mcbbsnews.py b/tests/platforms/test_mcbbsnews.py index d4b6d94..e551c10 100644 --- a/tests/platforms/test_mcbbsnews.py +++ b/tests/platforms/test_mcbbsnews.py @@ -1,6 +1,6 @@ import pytest import respx -from httpx import Response +from httpx import AsyncClient, Response from nonebug.app import App from .utils import get_file, get_json @@ -10,7 +10,7 @@ from .utils import get_file, get_json def mcbbsnews(app: App): from nonebot_bison.platform import platform_manager - return platform_manager["mcbbsnews"] + return platform_manager["mcbbsnews"](AsyncClient()) @pytest.fixture(scope="module") diff --git a/tests/platforms/test_ncm_artist.py b/tests/platforms/test_ncm_artist.py index 8f8f4cd..78296de 100644 --- a/tests/platforms/test_ncm_artist.py +++ b/tests/platforms/test_ncm_artist.py @@ -3,20 +3,20 @@ import typing import pytest import respx -from httpx import Response +from httpx import AsyncClient, Response from nonebug.app import App from .utils import get_json if typing.TYPE_CHECKING: - from nonebot_bison.platform.ncm_artist import NcmArtist + from nonebot_bison.platform.ncm import NcmArtist @pytest.fixture def ncm_artist(app: App): from nonebot_bison.platform import platform_manager - return platform_manager["ncm-artist"] + return platform_manager["ncm-artist"](AsyncClient()) @pytest.fixture(scope="module") diff --git a/tests/platforms/test_ncm_radio.py b/tests/platforms/test_ncm_radio.py index 9bab7d8..37b2160 100644 --- a/tests/platforms/test_ncm_radio.py +++ b/tests/platforms/test_ncm_radio.py @@ -3,20 +3,20 @@ import typing import pytest import respx -from httpx import Response +from httpx import AsyncClient, Response from nonebug.app import App from .utils import get_json if typing.TYPE_CHECKING: - from nonebot_bison.platform.ncm_radio import NcmRadio + from nonebot_bison.platform.ncm import NcmRadio @pytest.fixture def ncm_radio(app: App): from nonebot_bison.platform import platform_manager - return platform_manager["ncm-radio"] + return platform_manager["ncm-radio"](AsyncClient()) @pytest.fixture(scope="module") diff --git a/tests/platforms/test_platform.py b/tests/platforms/test_platform.py index b4c9677..b940160 100644 --- a/tests/platforms/test_platform.py +++ b/tests/platforms/test_platform.py @@ -53,12 +53,12 @@ def mock_platform_without_cats_tags(app: App): categories = {} has_target = True - def __init__(self): + def __init__(self, client): self.sub_index = 0 - super().__init__() + super().__init__(client) - @staticmethod - async def get_target_name(_: "Target"): + @classmethod + async def get_target_name(cls, client, _: "Target"): return "MockPlatform" def get_id(self, post: "RawPost") -> Any: @@ -82,7 +82,7 @@ def mock_platform_without_cats_tags(app: App): else: return raw_post_list_2 - return MockPlatform() + return MockPlatform(None) @pytest.fixture @@ -112,9 +112,9 @@ def mock_platform(app: App): Category(2): "视频", } - def __init__(self): + def __init__(self, client): self.sub_index = 0 - super().__init__() + super().__init__(client) @staticmethod async def get_target_name(_: "Target"): @@ -147,7 +147,7 @@ def mock_platform(app: App): else: return raw_post_list_2 - return MockPlatform() + return MockPlatform(None) @pytest.fixture @@ -180,9 +180,9 @@ def mock_platform_no_target(app: App, mock_scheduler_conf): has_target = False categories = {Category(1): "转发", Category(2): "视频", Category(3): "不支持"} - def __init__(self): + def __init__(self, client): self.sub_index = 0 - super().__init__() + super().__init__(client) @staticmethod async def get_target_name(_: "Target"): @@ -217,7 +217,7 @@ def mock_platform_no_target(app: App, mock_scheduler_conf): else: return raw_post_list_2 - return MockPlatform() + return MockPlatform @pytest.fixture @@ -241,12 +241,12 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf): Category(5): "leixing5", } - def __init__(self): + def __init__(self, client): self.sub_index = 0 - super().__init__() + super().__init__(client) - @staticmethod - async def get_target_name(_: "Target"): + @classmethod + async def get_target_name(cls, client, _: "Target"): return "MockPlatform" def get_id(self, post: "RawPost") -> Any: @@ -284,7 +284,7 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf): else: return list_2 - return MockPlatform() + return MockPlatform @pytest.fixture @@ -308,9 +308,9 @@ def mock_status_change(app: App): Category(2): "视频", } - def __init__(self): + def __init__(self, client): self.sub_index = 0 - super().__init__() + super().__init__(client) async def get_status(self, _: "Target"): if self.sub_index == 0: @@ -335,7 +335,7 @@ def mock_status_change(app: App): def get_category(self, raw_post): return raw_post["cat"] - return MockPlatform() + return MockPlatform(None) @pytest.mark.asyncio @@ -388,6 +388,7 @@ async def test_new_message_target(mock_platform, user_info_factory): @pytest.mark.asyncio async def test_new_message_no_target(mock_platform_no_target, user_info_factory): + mock_platform_no_target = mock_platform_no_target(None) res1 = await mock_platform_no_target.fetch_new_post( "dummy", [user_info_factory([1, 2], [])] ) @@ -457,11 +458,14 @@ async def test_group( user_info_factory, ): - from nonebot_bison.platform.platform import NoTargetGroup + from nonebot_bison.platform.platform import make_no_target_group from nonebot_bison.post import Post from nonebot_bison.types import Category, RawPost, Tag, Target - group_platform = NoTargetGroup([mock_platform_no_target, mock_platform_no_target_2]) + group_platform_class = make_no_target_group( + [mock_platform_no_target, mock_platform_no_target_2] + ) + group_platform = group_platform_class(None) res1 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])]) assert len(res1) == 0 res2 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])]) diff --git a/tests/platforms/test_platform_tag_filter.py b/tests/platforms/test_platform_tag_filter.py index c9d399a..6e340b7 100644 --- a/tests/platforms/test_platform_tag_filter.py +++ b/tests/platforms/test_platform_tag_filter.py @@ -1,4 +1,5 @@ import pytest +from httpx import AsyncClient from nonebug.app import App from .utils import get_json @@ -14,7 +15,7 @@ def test_cases(): async def test_filter_user_custom_tag(app: App, test_cases): from nonebot_bison.platform import platform_manager - bilibili = platform_manager["bilibili"] + bilibili = platform_manager["bilibili"](AsyncClient()) for case in test_cases: res = bilibili.is_banned_post(**case["case"]) assert res == case["result"] @@ -25,7 +26,7 @@ async def test_filter_user_custom_tag(app: App, test_cases): async def test_tag_separator(app: App): from nonebot_bison.platform import platform_manager - bilibili = platform_manager["bilibili"] + bilibili = platform_manager["bilibili"](AsyncClient()) tags = ["~111", "222", "333", "~444", "555"] res = bilibili.tag_separator(tags) assert res[0] == ["222", "333", "555"] diff --git a/tests/platforms/test_weibo.py b/tests/platforms/test_weibo.py index 26d94bc..69f1974 100644 --- a/tests/platforms/test_weibo.py +++ b/tests/platforms/test_weibo.py @@ -4,7 +4,7 @@ from datetime import datetime import feedparser import pytest import respx -from httpx import Response +from httpx import AsyncClient, Response from nonebug.app import App from pytz import timezone @@ -18,7 +18,7 @@ if typing.TYPE_CHECKING: def weibo(app: App): from nonebot_bison.platform import platform_manager - return platform_manager["weibo"] + return platform_manager["weibo"](AsyncClient()) @pytest.fixture(scope="module") @@ -35,7 +35,7 @@ async def test_get_name(weibo): profile_router.mock( return_value=Response(200, json=get_json("weibo_ak_profile.json")) ) - name = await weibo.get_target_name("6279793937") + name = await weibo.get_target_name(AsyncClient(), "6279793937") assert name == "明日方舟Arknights" diff --git a/tests/test_config_manager_abort.py b/tests/test_config_manager_abort.py index e01fc32..61124a7 100644 --- a/tests/test_config_manager_abort.py +++ b/tests/test_config_manager_abort.py @@ -10,7 +10,7 @@ from .utils import BotReply, fake_admin_user, fake_group_message_event # 选择platform阶段中止 @pytest.mark.asyncio @respx.mock -async def test_abort_add_on_platform(app: App, db_migration): +async def test_abort_add_on_platform(app: App, init_scheduler): from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.message import Message from nonebot_bison.config_manager import add_sub_matcher, common_platform @@ -57,7 +57,7 @@ async def test_abort_add_on_platform(app: App, db_migration): # 输入id阶段中止 @pytest.mark.asyncio @respx.mock -async def test_abort_add_on_id(app: App, db_migration): +async def test_abort_add_on_id(app: App, init_scheduler): from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.message import Message from nonebot_bison.config_manager import add_sub_matcher, common_platform @@ -114,7 +114,7 @@ async def test_abort_add_on_id(app: App, db_migration): # 输入订阅类别阶段中止 @pytest.mark.asyncio @respx.mock -async def test_abort_add_on_cats(app: App, db_migration): +async def test_abort_add_on_cats(app: App, init_scheduler): from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.message import Message from nonebot_bison.config_manager import add_sub_matcher, common_platform @@ -191,7 +191,7 @@ async def test_abort_add_on_cats(app: App, db_migration): # 输入标签阶段中止 @pytest.mark.asyncio @respx.mock -async def test_abort_add_on_tag(app: App, db_migration): +async def test_abort_add_on_tag(app: App, init_scheduler): from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.message import Message from nonebot_bison.config_manager import add_sub_matcher, common_platform diff --git a/tests/test_config_manager_add.py b/tests/test_config_manager_add.py index 383fe51..6f692eb 100644 --- a/tests/test_config_manager_add.py +++ b/tests/test_config_manager_add.py @@ -1,5 +1,3 @@ -from email import message - import pytest import respx from httpx import Response @@ -189,7 +187,7 @@ async def test_add_with_target_no_cat(app: App, init_scheduler): from nonebot_bison.config import config from nonebot_bison.config_manager import add_sub_matcher, common_platform from nonebot_bison.platform import platform_manager - from nonebot_bison.platform.ncm_artist import NcmArtist + from nonebot_bison.platform.ncm import NcmArtist ncm_router = respx.get("https://music.163.com/api/artist/albums/32540734") ncm_router.mock(return_value=Response(200, json=get_json("ncm_siren.json"))) diff --git a/tests/test_render.py b/tests/test_render.py index 6939660..2404f1e 100644 --- a/tests/test_render.py +++ b/tests/test_render.py @@ -1,6 +1,7 @@ import typing import pytest +from httpx import AsyncClient from nonebug.app import App @@ -29,7 +30,7 @@ VuePress 由两部分组成:第一部分是一个极简静态网站生成器 async def test_arknights(app: App): from nonebot_bison.platform.arknights import Arknights - ak = Arknights() + ak = Arknights(AsyncClient()) res = await ak.parse( {"webUrl": "https://ak.hycdn.cn/announce/IOS/announcement/854_1644580545.html"} )