inject http client to scheduler

This commit is contained in:
felinae98 2022-10-16 17:19:55 +08:00
parent c8a4644e40
commit 74b5074f04
No known key found for this signature in database
GPG Key ID: 00C8B010587FF610
30 changed files with 461 additions and 377 deletions

View File

@ -1,6 +1,7 @@
import nonebot import nonebot
from nonebot.adapters.onebot.v11.bot import Bot from nonebot.adapters.onebot.v11.bot import Bot
from ..apis import check_sub_target
from ..config import ( from ..config import (
NoSuchSubscribeException, NoSuchSubscribeException,
NoSuchTargetException, NoSuchTargetException,
@ -8,7 +9,7 @@ from ..config import (
config, config,
) )
from ..config.db_config import SubscribeDupException from ..config.db_config import SubscribeDupException
from ..platform import check_sub_target, platform_manager from ..platform import platform_manager
from ..types import Target as T_Target from ..types import Target as T_Target
from ..types import WeightConfig from ..types import WeightConfig
from .jwt import pack_jwt from .jwt import pack_jwt

View File

@ -0,0 +1,12 @@
from .platform import platform_manager
from .scheduler import scheduler_dict
from .types import Target
async def check_sub_target(platform_name: str, target: Target):
platform = platform_manager[platform_name]
scheduler_conf_class = platform.scheduler
scheduler = scheduler_dict[scheduler_conf_class]
client = await scheduler.scheduler_config_obj.get_query_name_client()
return await platform_manager[platform_name].get_target_name(client, target)

View File

@ -15,9 +15,10 @@ from nonebot.permission import SUPERUSER
from nonebot.rule import to_me from nonebot.rule import to_me
from nonebot.typing import T_State from nonebot.typing import T_State
from .apis import check_sub_target
from .config import config from .config import config
from .config.db_config import SubscribeDupException from .config.db_config import SubscribeDupException
from .platform import Platform, check_sub_target, platform_manager from .platform import Platform, platform_manager
from .plugin_config import plugin_config from .plugin_config import plugin_config
from .types import Category, Target, User from .types import Category, Target, User
from .utils import parse_text from .utils import parse_text
@ -117,9 +118,7 @@ def do_add_sub(add_sub: Type[Matcher]):
) + "请输入订阅用户的id\n查询id获取方法请回复:“查询”" ) + "请输入订阅用户的id\n查询id获取方法请回复:“查询”"
else: else:
state["id"] = "default" state["id"] = "default"
state["name"] = await platform_manager[state["platform"]].get_target_name( state["name"] = await check_sub_target(state["platform"], Target(""))
Target("")
)
async def parse_id(event: MessageEvent, state: T_State): async def parse_id(event: MessageEvent, state: T_State):
if not isinstance(state["id"], Message): if not isinstance(state["id"], Message):

View File

@ -2,29 +2,24 @@ from collections import defaultdict
from importlib import import_module from importlib import import_module
from pathlib import Path from pathlib import Path
from pkgutil import iter_modules from pkgutil import iter_modules
from typing import DefaultDict, Type
from .platform import NoTargetGroup, Platform from .platform import Platform, make_no_target_group
_package_dir = str(Path(__file__).resolve().parent) _package_dir = str(Path(__file__).resolve().parent)
for (_, module_name, _) in iter_modules([_package_dir]): for (_, module_name, _) in iter_modules([_package_dir]):
import_module(f"{__name__}.{module_name}") import_module(f"{__name__}.{module_name}")
async def check_sub_target(target_type, target): _platform_list: DefaultDict[str, list[Type[Platform]]] = defaultdict(list)
return await platform_manager[target_type].get_target_name(target)
_platform_list = defaultdict(list)
for _platform in Platform.registry: for _platform in Platform.registry:
if not _platform.enabled: if not _platform.enabled:
continue continue
_platform_list[_platform.platform_name].append(_platform) _platform_list[_platform.platform_name].append(_platform)
platform_manager: dict[str, Platform] = dict() platform_manager: dict[str, Type[Platform]] = dict()
for name, platform_list in _platform_list.items(): for name, platform_list in _platform_list.items():
if len(platform_list) == 1: if len(platform_list) == 1:
platform_manager[name] = platform_list[0]() platform_manager[name] = platform_list[0]
else: else:
platform_manager[name] = NoTargetGroup( platform_manager[name] = make_no_target_group(platform_list)
[_platform() for _platform in platform_list]
)

View File

@ -1,12 +1,12 @@
import json import json
from typing import Any from typing import Any, Optional
from bs4 import BeautifulSoup as bs from bs4 import BeautifulSoup as bs
from httpx import AsyncClient
from nonebot.plugin import require from nonebot.plugin import require
from ..post import Post from ..post import Post
from ..types import Category, RawPost, Target from ..types import Category, RawPost, Target
from ..utils import http_client
from ..utils.scheduler_config import SchedulerConfig from ..utils.scheduler_config import SchedulerConfig
from .platform import CategoryNotSupport, NewMessage, StatusChange from .platform import CategoryNotSupport, NewMessage, StatusChange
@ -29,15 +29,17 @@ class Arknights(NewMessage):
scheduler = ArknightsSchedConf scheduler = ArknightsSchedConf
has_target = False has_target = False
async def get_target_name(self, _: Target) -> str: @classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "明日方舟游戏信息" return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]: async def get_sub_list(self, _) -> list[RawPost]:
async with http_client() as client: raw_data = await self.client.get(
raw_data = await client.get( "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/announcement.meta.json"
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/announcement.meta.json" )
) return json.loads(raw_data.text)["announceList"]
return json.loads(raw_data.text)["announceList"]
def get_id(self, post: RawPost) -> Any: def get_id(self, post: RawPost) -> Any:
return post["announceId"] return post["announceId"]
@ -51,8 +53,7 @@ class Arknights(NewMessage):
async def parse(self, raw_post: RawPost) -> Post: async def parse(self, raw_post: RawPost) -> Post:
announce_url = raw_post["webUrl"] announce_url = raw_post["webUrl"]
text = "" text = ""
async with http_client() as client: raw_html = await self.client.get(announce_url)
raw_html = await client.get(announce_url)
soup = bs(raw_html.text, "html.parser") soup = bs(raw_html.text, "html.parser")
pics = [] pics = []
if soup.find("div", class_="standerd-container"): if soup.find("div", class_="standerd-container"):
@ -101,17 +102,19 @@ class AkVersion(StatusChange):
scheduler = ArknightsSchedConf scheduler = ArknightsSchedConf
has_target = False has_target = False
async def get_target_name(self, _: Target) -> str: @classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "明日方舟游戏信息" return "明日方舟游戏信息"
async def get_status(self, _): async def get_status(self, _):
async with http_client() as client: res_ver = await self.client.get(
res_ver = await client.get( "https://ak-conf.hypergryph.com/config/prod/official/IOS/version"
"https://ak-conf.hypergryph.com/config/prod/official/IOS/version" )
) res_preanounce = await self.client.get(
res_preanounce = await client.get( "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json"
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json" )
)
res = res_ver.json() res = res_ver.json()
res.update(res_preanounce.json()) res.update(res_preanounce.json())
return res return res
@ -156,13 +159,17 @@ class MonsterSiren(NewMessage):
scheduler = ArknightsSchedConf scheduler = ArknightsSchedConf
has_target = False has_target = False
async def get_target_name(self, _: Target) -> str: @classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "明日方舟游戏信息" return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]: async def get_sub_list(self, _) -> list[RawPost]:
async with http_client() as client: raw_data = await self.client.get(
raw_data = await client.get("https://monster-siren.hypergryph.com/api/news") "https://monster-siren.hypergryph.com/api/news"
return raw_data.json()["data"]["list"] )
return raw_data.json()["data"]["list"]
def get_id(self, post: RawPost) -> Any: def get_id(self, post: RawPost) -> Any:
return post["cid"] return post["cid"]
@ -175,16 +182,15 @@ class MonsterSiren(NewMessage):
async def parse(self, raw_post: RawPost) -> Post: async def parse(self, raw_post: RawPost) -> Post:
url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}' url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}'
async with http_client() as client: res = await self.client.get(
res = await client.get( f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}'
f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}' )
) raw_data = res.json()
raw_data = res.json() content = raw_data["data"]["content"]
content = raw_data["data"]["content"] content = content.replace("</p>", "</p>\n")
content = content.replace("</p>", "</p>\n") soup = bs(content, "html.parser")
soup = bs(content, "html.parser") imgs = list(map(lambda x: x["src"], soup("img")))
imgs = list(map(lambda x: x["src"], soup("img"))) text = f'{raw_post["title"]}\n{soup.text.strip()}'
text = f'{raw_post["title"]}\n{soup.text.strip()}'
return Post( return Post(
"monster-siren", "monster-siren",
text=text, text=text,
@ -207,15 +213,17 @@ class TerraHistoricusComic(NewMessage):
scheduler = ArknightsSchedConf scheduler = ArknightsSchedConf
has_target = False has_target = False
async def get_target_name(self, _: Target) -> str: @classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "明日方舟游戏信息" return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]: async def get_sub_list(self, _) -> list[RawPost]:
async with http_client() as client: raw_data = await self.client.get(
raw_data = await client.get( "https://terra-historicus.hypergryph.com/api/recentUpdate"
"https://terra-historicus.hypergryph.com/api/recentUpdate" )
) return raw_data.json()["data"]
return raw_data.json()["data"]
def get_id(self, post: RawPost) -> Any: def get_id(self, post: RawPost) -> Any:
return f'{post["comicCid"]}/{post["episodeCid"]}' return f'{post["comicCid"]}/{post["episodeCid"]}'

View File

@ -5,6 +5,7 @@ from datetime import datetime, timedelta
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
import httpx import httpx
from httpx import AsyncClient
from nonebot.log import logger from nonebot.log import logger
from ..post import Post from ..post import Post
@ -20,35 +21,36 @@ class BilibiliSchedConf(SchedulerConfig):
schedule_type = "interval" schedule_type = "interval"
schedule_setting = {"seconds": 10} schedule_setting = {"seconds": 10}
_client_refresh_time: datetime
from .platform import CategoryNotSupport, NewMessage, StatusChange
class _BilibiliClient:
_http_client: httpx.AsyncClient
_client_refresh_time: Optional[datetime]
cookie_expire_time = timedelta(hours=5) cookie_expire_time = timedelta(hours=5)
def __init__(self):
self._client_refresh_time = datetime(
year=2000, month=1, day=1
) # an expired time
super().__init__()
async def _init_session(self): async def _init_session(self):
self._http_client = httpx.AsyncClient(**http_args) res = await self.default_http_client.get("https://www.bilibili.com/")
res = await self._http_client.get("https://www.bilibili.com/")
if res.status_code != 200: if res.status_code != 200:
logger.warning("unable to refresh temp cookie") logger.warning("unable to refresh temp cookie")
else: else:
self._client_refresh_time = datetime.now() self._client_refresh_time = datetime.now()
async def _refresh_client(self): async def _refresh_client(self):
if ( if datetime.now() - self._client_refresh_time > self.cookie_expire_time:
getattr(self, "_client_refresh_time", None) is None
or datetime.now() - self._client_refresh_time
> self.cookie_expire_time # type:ignore
or self._http_client is None
):
await self._init_session() await self._init_session()
async def get_client(self, target: Target) -> AsyncClient:
await self._refresh_client()
return await super().get_client(target)
class Bilibili(_BilibiliClient, NewMessage): async def get_query_name_client(self) -> AsyncClient:
await self._refresh_client()
return await super().get_query_name_client()
class Bilibili(NewMessage):
categories = { categories = {
1: "一般动态", 1: "一般动态",
@ -67,17 +69,11 @@ class Bilibili(_BilibiliClient, NewMessage):
has_target = True has_target = True
parse_target_promot = "请输入用户主页的链接" parse_target_promot = "请输入用户主页的链接"
def ensure_client(fun: Callable): # type:ignore @classmethod
@functools.wraps(fun) async def get_target_name(
async def wrapped(self, *args, **kwargs): cls, client: AsyncClient, target: Target
await self._refresh_client() ) -> Optional[str]:
return await fun(self, *args, **kwargs) res = await client.get(
return wrapped
@ensure_client
async def get_target_name(self, target: Target) -> Optional[str]:
res = await self._http_client.get(
"https://api.bilibili.com/x/space/acc/info", params={"mid": target} "https://api.bilibili.com/x/space/acc/info", params={"mid": target}
) )
res_data = json.loads(res.text) res_data = json.loads(res.text)
@ -85,18 +81,18 @@ class Bilibili(_BilibiliClient, NewMessage):
return None return None
return res_data["data"]["name"] return res_data["data"]["name"]
async def parse_target(self, target_text: str) -> Target: @classmethod
async def parse_target(cls, target_text: str) -> Target:
if re.match(r"\d+", target_text): if re.match(r"\d+", target_text):
return Target(target_text) return Target(target_text)
elif m := re.match(r"(?:https?://)?space\.bilibili\.com/(\d+)", target_text): elif m := re.match(r"(?:https?://)?space\.bilibili\.com/(\d+)", target_text):
return Target(m.group(1)) return Target(m.group(1))
else: else:
raise self.ParseTargetException() raise cls.ParseTargetException()
@ensure_client
async def get_sub_list(self, target: Target) -> list[RawPost]: async def get_sub_list(self, target: Target) -> list[RawPost]:
params = {"host_uid": target, "offset": 0, "need_top": 0} params = {"host_uid": target, "offset": 0, "need_top": 0}
res = await self._http_client.get( res = await self.client.get(
"https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history", "https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history",
params=params, params=params,
timeout=4.0, timeout=4.0,
@ -202,7 +198,7 @@ class Bilibili(_BilibiliClient, NewMessage):
return Post("bilibili", text=text, url=url, pics=pic, target_name=target_name) return Post("bilibili", text=text, url=url, pics=pic, target_name=target_name)
class Bilibililive(_BilibiliClient, StatusChange): class Bilibililive(StatusChange):
# Author : Sichongzou # Author : Sichongzou
# Date : 2022-5-18 8:54 # Date : 2022-5-18 8:54
# Description : bilibili开播提醒 # Description : bilibili开播提醒
@ -216,17 +212,11 @@ class Bilibililive(_BilibiliClient, StatusChange):
name = "Bilibili直播" name = "Bilibili直播"
has_target = True has_target = True
def ensure_client(fun: Callable): # type:ignore @classmethod
@functools.wraps(fun) async def get_target_name(
async def wrapped(self, *args, **kwargs): cls, client: AsyncClient, target: Target
await self._refresh_client() ) -> Optional[str]:
return await fun(self, *args, **kwargs) res = await client.get(
return wrapped
@ensure_client
async def get_target_name(self, target: Target) -> Optional[str]:
res = await self._http_client.get(
"https://api.bilibili.com/x/space/acc/info", params={"mid": target} "https://api.bilibili.com/x/space/acc/info", params={"mid": target}
) )
res_data = json.loads(res.text) res_data = json.loads(res.text)
@ -234,10 +224,9 @@ class Bilibililive(_BilibiliClient, StatusChange):
return None return None
return res_data["data"]["name"] return res_data["data"]["name"]
@ensure_client
async def get_status(self, target: Target): async def get_status(self, target: Target):
params = {"mid": target} params = {"mid": target}
res = await self._http_client.get( res = await self.client.get(
"https://api.bilibili.com/x/space/acc/info", "https://api.bilibili.com/x/space/acc/info",
params=params, params=params,
timeout=4.0, timeout=4.0,
@ -279,7 +268,7 @@ class Bilibililive(_BilibiliClient, StatusChange):
) )
class BilibiliBangumi(_BilibiliClient, StatusChange): class BilibiliBangumi(StatusChange):
categories = {} categories = {}
platform_name = "bilibili-bangumi" platform_name = "bilibili-bangumi"
@ -293,23 +282,18 @@ class BilibiliBangumi(_BilibiliClient, StatusChange):
_url = "https://api.bilibili.com/pgc/review/user" _url = "https://api.bilibili.com/pgc/review/user"
def ensure_client(fun: Callable): # type:ignore @classmethod
@functools.wraps(fun) async def get_target_name(
async def wrapped(self, *args, **kwargs): cls, client: AsyncClient, target: Target
await self._refresh_client() ) -> Optional[str]:
return await fun(self, *args, **kwargs) res = await client.get(cls._url, params={"media_id": target})
return wrapped
@ensure_client
async def get_target_name(self, target: Target) -> Optional[str]:
res = await self._http_client.get(self._url, params={"media_id": target})
res_data = res.json() res_data = res.json()
if res_data["code"]: if res_data["code"]:
return None return None
return res_data["result"]["media"]["title"] return res_data["result"]["media"]["title"]
async def parse_target(self, target_string: str) -> Target: @classmethod
async def parse_target(cls, target_string: str) -> Target:
if re.match(r"\d+", target_string): if re.match(r"\d+", target_string):
return Target(target_string) return Target(target_string)
elif m := re.match(r"md(\d+)", target_string): elif m := re.match(r"md(\d+)", target_string):
@ -318,11 +302,10 @@ class BilibiliBangumi(_BilibiliClient, StatusChange):
r"(?:https?://)?www\.bilibili\.com/bangumi/media/md(\d+)/", target_string r"(?:https?://)?www\.bilibili\.com/bangumi/media/md(\d+)/", target_string
): ):
return Target(m.group(1)) return Target(m.group(1))
raise self.ParseTargetException() raise cls.ParseTargetException()
@ensure_client
async def get_status(self, target: Target): async def get_status(self, target: Target):
res = await self._http_client.get( res = await self.client.get(
self._url, self._url,
params={"media_id": target}, params={"media_id": target},
timeout=4.0, timeout=4.0,
@ -343,9 +326,8 @@ class BilibiliBangumi(_BilibiliClient, StatusChange):
else: else:
return [] return []
@ensure_client
async def parse(self, raw_post: RawPost) -> Post: async def parse(self, raw_post: RawPost) -> Post:
detail_res = await self._http_client.get( detail_res = await self.client.get(
f'http://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}' f'http://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}'
) )
detail_dict = detail_res.json() detail_dict = detail_res.json()

View File

@ -1,4 +1,6 @@
from typing import Any from typing import Any, Optional
from httpx import AsyncClient
from ..post import Post from ..post import Post
from ..types import RawPost, Target from ..types import RawPost, Target
@ -18,7 +20,10 @@ class FF14(NewMessage):
scheduler = scheduler("interval", {"seconds": 60}) scheduler = scheduler("interval", {"seconds": 60})
has_target = False has_target = False
async def get_target_name(self, _: Target) -> str: @classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "最终幻想XIV官方公告" return "最终幻想XIV官方公告"
async def get_sub_list(self, _) -> list[RawPost]: async def get_sub_list(self, _) -> list[RawPost]:

View File

@ -1,9 +1,10 @@
import re import re
import time import time
from typing import Literal from typing import Literal, Optional
import httpx import httpx
from bs4 import BeautifulSoup, NavigableString, Tag from bs4 import BeautifulSoup, NavigableString, Tag
from httpx import AsyncClient
from ..post import Post from ..post import Post
from ..types import Category, RawPost, Target from ..types import Category, RawPost, Target
@ -42,8 +43,11 @@ class McbbsNews(NewMessage):
scheduler = scheduler("interval", {"hours": 1}) scheduler = scheduler("interval", {"hours": 1})
has_target = False has_target = False
async def get_target_name(self, _: Target) -> str: @classmethod
return self.name async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return cls.name
async def get_sub_list(self, _: Target) -> list[RawPost]: async def get_sub_list(self, _: Target) -> list[RawPost]:
url = "https://www.mcbbs.net/forum-news-1.html" url = "https://www.mcbbs.net/forum-news-1.html"

View File

@ -0,0 +1,146 @@
import re
from typing import Any, Optional
from httpx import AsyncClient
from ..post import Post
from ..types import RawPost, Target
from ..utils import SchedulerConfig, http_client
from .platform import NewMessage
class NcmSchedConf(SchedulerConfig):
name = "music.163.com"
schedule_type = "interval"
schedule_setting = {"minutes": 1}
class NcmArtist(NewMessage):
categories = {}
platform_name = "ncm-artist"
enable_tag = False
enabled = True
is_common = True
scheduler = NcmSchedConf
name = "网易云-歌手"
has_target = True
parse_target_promot = "请输入歌手主页包含数字ID的链接"
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async with http_client() as client:
res = await client.get(
"https://music.163.com/api/artist/albums/{}".format(target),
headers={"Referer": "https://music.163.com/"},
)
res_data = res.json()
if res_data["code"] != 200:
return
return res_data["artist"]["name"]
@classmethod
async def parse_target(cls, target_text: str) -> Target:
if re.match(r"^\d+$", target_text):
return Target(target_text)
elif match := re.match(
r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text
):
return Target(match.group(1))
else:
raise cls.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]:
async with http_client() as client:
res = await client.get(
"https://music.163.com/api/artist/albums/{}".format(target),
headers={"Referer": "https://music.163.com/"},
)
res_data = res.json()
if res_data["code"] != 200:
return []
else:
return res_data["hotAlbums"]
def get_id(self, post: RawPost) -> Any:
return post["id"]
def get_date(self, post: RawPost) -> int:
return post["publishTime"] // 1000
async def parse(self, raw_post: RawPost) -> Post:
text = "新专辑发布:{}".format(raw_post["name"])
target_name = raw_post["artist"]["name"]
pics = [raw_post["picUrl"]]
url = "https://music.163.com/#/album?id={}".format(raw_post["id"])
return Post(
"ncm-artist", text=text, url=url, pics=pics, target_name=target_name
)
class NcmRadio(NewMessage):
categories = {}
platform_name = "ncm-radio"
enable_tag = False
enabled = True
is_common = False
scheduler = NcmSchedConf
name = "网易云-电台"
has_target = True
parse_target_promot = "请输入主播电台主页包含数字ID的链接"
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async with http_client() as client:
res = await client.post(
"http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"},
data={"radioId": target, "limit": 1000, "offset": 0},
)
res_data = res.json()
if res_data["code"] != 200 or res_data["programs"] == 0:
return
return res_data["programs"][0]["radio"]["name"]
@classmethod
async def parse_target(cls, target_text: str) -> Target:
if re.match(r"^\d+$", target_text):
return Target(target_text)
elif match := re.match(
r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text
):
return Target(match.group(1))
else:
raise cls.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]:
async with http_client() as client:
res = await client.post(
"http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"},
data={"radioId": target, "limit": 1000, "offset": 0},
)
res_data = res.json()
if res_data["code"] != 200:
return []
else:
return res_data["programs"]
def get_id(self, post: RawPost) -> Any:
return post["id"]
def get_date(self, post: RawPost) -> int:
return post["createTime"] // 1000
async def parse(self, raw_post: RawPost) -> Post:
text = "网易云电台更新:{}".format(raw_post["name"])
target_name = raw_post["radio"]["name"]
pics = [raw_post["coverUrl"]]
url = "https://music.163.com/#/program/{}".format(raw_post["id"])
return Post("ncm-radio", text=text, url=url, pics=pics, target_name=target_name)

View File

@ -1,75 +0,0 @@
import re
from typing import Any, Optional
from ..post import Post
from ..types import RawPost, Target
from ..utils import SchedulerConfig, http_client
from .platform import NewMessage
class NcmSchedConf(SchedulerConfig):
name = "music.163.com"
schedule_type = "interval"
schedule_setting = {"minutes": 1}
class NcmArtist(NewMessage):
categories = {}
platform_name = "ncm-artist"
enable_tag = False
enabled = True
is_common = True
scheduler = NcmSchedConf
name = "网易云-歌手"
has_target = True
parse_target_promot = "请输入歌手主页包含数字ID的链接"
async def get_target_name(self, target: Target) -> Optional[str]:
async with http_client() as client:
res = await client.get(
"https://music.163.com/api/artist/albums/{}".format(target),
headers={"Referer": "https://music.163.com/"},
)
res_data = res.json()
if res_data["code"] != 200:
return
return res_data["artist"]["name"]
async def parse_target(self, target_text: str) -> Target:
if re.match(r"^\d+$", target_text):
return Target(target_text)
elif match := re.match(
r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text
):
return Target(match.group(1))
else:
raise self.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]:
async with http_client() as client:
res = await client.get(
"https://music.163.com/api/artist/albums/{}".format(target),
headers={"Referer": "https://music.163.com/"},
)
res_data = res.json()
if res_data["code"] != 200:
return []
else:
return res_data["hotAlbums"]
def get_id(self, post: RawPost) -> Any:
return post["id"]
def get_date(self, post: RawPost) -> int:
return post["publishTime"] // 1000
async def parse(self, raw_post: RawPost) -> Post:
text = "新专辑发布:{}".format(raw_post["name"])
target_name = raw_post["artist"]["name"]
pics = [raw_post["picUrl"]]
url = "https://music.163.com/#/album?id={}".format(raw_post["id"])
return Post(
"ncm-artist", text=text, url=url, pics=pics, target_name=target_name
)

View File

@ -1,69 +0,0 @@
import re
from typing import Any, Optional
from ..post import Post
from ..types import RawPost, Target
from ..utils import http_client
from .ncm_artist import NcmSchedConf
from .platform import NewMessage
class NcmRadio(NewMessage):
categories = {}
platform_name = "ncm-radio"
enable_tag = False
enabled = True
is_common = False
scheduler = NcmSchedConf
name = "网易云-电台"
has_target = True
parse_target_promot = "请输入主播电台主页包含数字ID的链接"
async def get_target_name(self, target: Target) -> Optional[str]:
async with http_client() as client:
res = await client.post(
"http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"},
data={"radioId": target, "limit": 1000, "offset": 0},
)
res_data = res.json()
if res_data["code"] != 200 or res_data["programs"] == 0:
return
return res_data["programs"][0]["radio"]["name"]
async def parse_target(self, target_text: str) -> Target:
if re.match(r"^\d+$", target_text):
return Target(target_text)
elif match := re.match(
r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text
):
return Target(match.group(1))
else:
raise self.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]:
async with http_client() as client:
res = await client.post(
"http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"},
data={"radioId": target, "limit": 1000, "offset": 0},
)
res_data = res.json()
if res_data["code"] != 200:
return []
else:
return res_data["programs"]
def get_id(self, post: RawPost) -> Any:
return post["id"]
def get_date(self, post: RawPost) -> int:
return post["createTime"] // 1000
async def parse(self, raw_post: RawPost) -> Post:
text = "网易云电台更新:{}".format(raw_post["name"])
target_name = raw_post["radio"]["name"]
pics = [raw_post["coverUrl"]]
url = "https://music.163.com/#/program/{}".format(raw_post["id"])
return Post("ncm-radio", text=text, url=url, pics=pics, target_name=target_name)

View File

@ -1,12 +1,14 @@
import json import json
import ssl import ssl
import time import time
import typing
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Collection, Literal, Optional, Type from typing import Any, Collection, Optional, Type
import httpx import httpx
from httpx import AsyncClient
from nonebot.log import logger from nonebot.log import logger
from ..plugin_config import plugin_config from ..plugin_config import plugin_config
@ -34,11 +36,23 @@ class RegistryMeta(type):
super().__init__(name, bases, namespace, **kwargs) super().__init__(name, bases, namespace, **kwargs)
class RegistryABCMeta(RegistryMeta, ABC): class PlatformMeta(RegistryMeta):
categories: dict[Category, str]
def __init__(cls, name, bases, namespace, **kwargs):
cls.reverse_category = {}
if hasattr(cls, "categories") and cls.categories:
for key, val in cls.categories.items():
cls.reverse_category[val] = key
super().__init__(name, bases, namespace, **kwargs)
class PlatformABCMeta(PlatformMeta, ABC):
... ...
class Platform(metaclass=RegistryABCMeta, base=True): class Platform(metaclass=PlatformABCMeta, base=True):
scheduler: Type[SchedulerConfig] scheduler: Type[SchedulerConfig]
is_common: bool is_common: bool
@ -50,9 +64,15 @@ class Platform(metaclass=RegistryABCMeta, base=True):
store: dict[Target, Any] store: dict[Target, Any]
platform_name: str platform_name: str
parse_target_promot: Optional[str] = None parse_target_promot: Optional[str] = None
registry: list[Type["Platform"]]
client: AsyncClient
reverse_category: dict[str, Category]
@classmethod
@abstractmethod @abstractmethod
async def get_target_name(self, target: Target) -> Optional[str]: async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
... ...
@abstractmethod @abstractmethod
@ -88,17 +108,16 @@ class Platform(metaclass=RegistryABCMeta, base=True):
"actually function called" "actually function called"
return await self.parse(raw_post) return await self.parse(raw_post)
def __init__(self): def __init__(self, client: AsyncClient):
super().__init__() super().__init__()
self.reverse_category = {}
for key, val in self.categories.items():
self.reverse_category[val] = key
self.store = dict() self.store = dict()
self.client = client
class ParseTargetException(Exception): class ParseTargetException(Exception):
pass pass
async def parse_target(self, target_string: str) -> Target: @classmethod
async def parse_target(cls, target_string: str) -> Target:
return Target(target_string) return Target(target_string)
@abstractmethod @abstractmethod
@ -188,8 +207,8 @@ class Platform(metaclass=RegistryABCMeta, base=True):
class MessageProcess(Platform, abstract=True): class MessageProcess(Platform, abstract=True):
"General message process fetch, parse, filter progress" "General message process fetch, parse, filter progress"
def __init__(self): def __init__(self, client: AsyncClient):
super().__init__() super().__init__(client)
self.parse_cache: dict[Any, Post] = dict() self.parse_cache: dict[Any, Post] = dict()
@abstractmethod @abstractmethod
@ -362,55 +381,82 @@ class SimplePost(MessageProcess, abstract=True):
return res return res
class NoTargetGroup(Platform, abstract=True): def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
enable_tag = False
if typing.TYPE_CHECKING:
class NoTargetGroup(Platform, abstract=True):
platform_list: list[Type[Platform]]
platform_obj_list: list[Platform]
DUMMY_STR = "_DUMMY" DUMMY_STR = "_DUMMY"
enabled = True
has_target = False
def __init__(self, platform_list: list[Platform]): platform_name = platform_list[0].platform_name
self.platform_list = platform_list name = DUMMY_STR
self.platform_name = platform_list[0].platform_name categories_keys = set()
name = self.DUMMY_STR categories = {}
self.categories = {} scheduler = platform_list[0].scheduler
categories_keys = set()
self.scheduler = platform_list[0].scheduler
for platform in platform_list:
if platform.has_target:
raise RuntimeError(
"Platform {} should have no target".format(platform.name)
)
if name == self.DUMMY_STR:
name = platform.name
elif name != platform.name:
raise RuntimeError(
"Platform name for {} not fit".format(self.platform_name)
)
platform_category_key_set = set(platform.categories.keys())
if platform_category_key_set & categories_keys:
raise RuntimeError(
"Platform categories for {} duplicate".format(self.platform_name)
)
categories_keys |= platform_category_key_set
self.categories.update(platform.categories)
if platform.scheduler != self.scheduler:
raise RuntimeError(
"Platform scheduler for {} not fit".format(self.platform_name)
)
self.name = name
self.is_common = platform_list[0].is_common
super().__init__()
def __str__(self): for platform in platform_list:
if platform.has_target:
raise RuntimeError(
"Platform {} should have no target".format(platform.name)
)
if name == DUMMY_STR:
name = platform.name
elif name != platform.name:
raise RuntimeError("Platform name for {} not fit".format(platform_name))
platform_category_key_set = set(platform.categories.keys())
if platform_category_key_set & categories_keys:
raise RuntimeError(
"Platform categories for {} duplicate".format(platform_name)
)
categories_keys |= platform_category_key_set
categories.update(platform.categories)
if platform.scheduler != scheduler:
raise RuntimeError(
"Platform scheduler for {} not fit".format(platform_name)
)
def __init__(self: "NoTargetGroup", client: AsyncClient):
Platform.__init__(self, client)
self.platform_obj_list = []
for platform_class in self.platform_list:
self.platform_obj_list.append(platform_class(client))
def __str__(self: "NoTargetGroup") -> str:
return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]" return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]"
async def get_target_name(self, _): @classmethod
return await self.platform_list[0].get_target_name(_) async def get_target_name(cls, client: AsyncClient, target: Target):
return await platform_list[0].get_target_name(client, target)
async def fetch_new_post(self, target, users): async def fetch_new_post(
self: "NoTargetGroup", target: Target, users: list[UserSubInfo]
):
res = defaultdict(list) res = defaultdict(list)
for platform in self.platform_list: for platform in self.platform_obj_list:
platform_res = await platform.fetch_new_post(target=target, users=users) platform_res = await platform.fetch_new_post(target=target, users=users)
for user, posts in platform_res: for user, posts in platform_res:
res[user].extend(posts) res[user].extend(posts)
return [[key, val] for key, val in res.items()] return [[key, val] for key, val in res.items()]
return type(
"NoTargetGroup",
(Platform,),
{
"platform_list": platform_list,
"platform_name": platform_list[0].platform_name,
"name": name,
"categories": categories,
"scheduler": scheduler,
"is_common": platform_list[0].is_common,
"enabled": True,
"has_target": False,
"enable_tag": False,
"__init__": __init__,
"get_target_name": get_target_name,
"fetch_new_post": fetch_new_post,
},
abstract=True,
)

View File

@ -3,6 +3,7 @@ from typing import Any, Optional
import feedparser import feedparser
from bs4 import BeautifulSoup as bs from bs4 import BeautifulSoup as bs
from httpx import AsyncClient
from ..post import Post from ..post import Post
from ..types import RawPost, Target from ..types import RawPost, Target
@ -21,7 +22,10 @@ class Rss(NewMessage):
scheduler = scheduler("interval", {"seconds": 30}) scheduler = scheduler("interval", {"seconds": 30})
has_target = True has_target = True
async def get_target_name(self, target: Target) -> Optional[str]: @classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async with http_client() as client: async with http_client() as client:
res = await client.get(target, timeout=10.0) res = await client.get(target, timeout=10.0)
feed = feedparser.parse(res.text) feed = feedparser.parse(res.text)

View File

@ -5,6 +5,7 @@ from datetime import datetime
from typing import Any, Optional from typing import Any, Optional
from bs4 import BeautifulSoup as bs from bs4 import BeautifulSoup as bs
from httpx import AsyncClient
from nonebot.log import logger from nonebot.log import logger
from ..post import Post from ..post import Post
@ -36,7 +37,10 @@ class Weibo(NewMessage):
has_target = True has_target = True
parse_target_promot = "请输入用户主页包含数字UID的链接" parse_target_promot = "请输入用户主页包含数字UID的链接"
async def get_target_name(self, target: Target) -> Optional[str]: @classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async with http_client() as client: async with http_client() as client:
param = {"containerid": "100505" + target} param = {"containerid": "100505" + target}
res = await client.get( res = await client.get(
@ -48,14 +52,15 @@ class Weibo(NewMessage):
else: else:
return None return None
async def parse_target(self, target_text: str) -> Target: @classmethod
async def parse_target(cls, target_text: str) -> Target:
if re.match(r"\d+", target_text): if re.match(r"\d+", target_text):
return Target(target_text) return Target(target_text)
elif match := re.match(r"(?:https?://)?weibo\.com/u/(\d+)", target_text): elif match := re.match(r"(?:https?://)?weibo\.com/u/(\d+)", target_text):
# 都2202年了应该不会有http了吧不过还是防一手 # 都2202年了应该不会有http了吧不过还是防一手
return Target(match.group(1)) return Target(match.group(1))
else: else:
raise self.ParseTargetException() raise cls.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]: async def get_sub_list(self, target: Target) -> list[RawPost]:
async with http_client() as client: async with http_client() as client:

View File

@ -7,7 +7,6 @@ from nonebot.log import logger
from ..config import config from ..config import config
from ..platform import platform_manager from ..platform import platform_manager
from ..platform.platform import Platform
from ..send import send_msgs from ..send import send_msgs
from ..types import Target from ..types import Target
from ..utils import SchedulerConfig from ..utils import SchedulerConfig
@ -36,6 +35,7 @@ class Scheduler:
logger.error(f"scheduler config [{self.name}] not found, exiting") logger.error(f"scheduler config [{self.name}] not found, exiting")
raise RuntimeError(f"{self.name} not found") raise RuntimeError(f"{self.name} not found")
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.scheduler_config_obj = self.scheduler_config()
self.schedulable_list = [] self.schedulable_list = []
for platform_name, target in schedulables: for platform_name, target in schedulables:
self.schedulable_list.append( self.schedulable_list.append(
@ -86,7 +86,10 @@ class Scheduler:
send_userinfo_list = await config.get_platform_target_subscribers( send_userinfo_list = await config.get_platform_target_subscribers(
schedulable.platform_name, schedulable.target schedulable.platform_name, schedulable.target
) )
to_send = await platform_manager[schedulable.platform_name].do_fetch_new_post( platform_obj = platform_manager[schedulable.platform_name](
await self.scheduler_config_obj.get_client(schedulable.target)
)
to_send = await platform_obj.do_fetch_new_post(
schedulable.target, send_userinfo_list schedulable.target, send_userinfo_list
) )
if not to_send: if not to_send:

View File

@ -1,5 +1,10 @@
from typing import Literal, Type from typing import Literal, Type
from httpx import AsyncClient
from ..types import Target
from .http import http_client
class SchedulerConfig: class SchedulerConfig:
@ -10,6 +15,15 @@ class SchedulerConfig:
def __str__(self): def __str__(self):
return f"[{self.name}]-{self.name}-{self.schedule_setting}" return f"[{self.name}]-{self.name}-{self.schedule_setting}"
def __init__(self):
self.default_http_client = http_client()
async def get_client(self, target: Target) -> AsyncClient:
return self.default_http_client
async def get_query_name_client(self) -> AsyncClient:
return self.default_http_client
def scheduler( def scheduler(
schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict

View File

@ -1,6 +1,6 @@
import pytest import pytest
import respx import respx
from httpx import Response from httpx import AsyncClient, Response
from nonebug.app import App from nonebug.app import App
from .utils import get_file, get_json from .utils import get_file, get_json
@ -10,7 +10,7 @@ from .utils import get_file, get_json
def arknights(app: App): def arknights(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
return platform_manager["arknights"] return platform_manager["arknights"](AsyncClient())
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -3,7 +3,7 @@ from datetime import datetime
import pytest import pytest
import respx import respx
from httpx import Response from httpx import AsyncClient, Response
from nonebug.app import App from nonebug.app import App
from pytz import timezone from pytz import timezone
@ -23,7 +23,7 @@ if typing.TYPE_CHECKING:
def bilibili(app: App): def bilibili(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
return platform_manager["bilibili"] return platform_manager["bilibili"](AsyncClient())
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -2,7 +2,7 @@ import typing
import pytest import pytest
import respx import respx
from httpx import Response from httpx import AsyncClient, Response
from nonebug.app import App from nonebug.app import App
from .utils import get_json from .utils import get_json
@ -15,7 +15,7 @@ if typing.TYPE_CHECKING:
def bili_bangumi(app: App): def bili_bangumi(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
return platform_manager["bilibili-bangumi"] return platform_manager["bilibili-bangumi"](AsyncClient())
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -1,6 +1,6 @@
import pytest import pytest
import respx import respx
from httpx import Response from httpx import AsyncClient, Response
from nonebug.app import App from nonebug.app import App
from .utils import get_json from .utils import get_json
@ -10,7 +10,7 @@ from .utils import get_json
def bili_live(app: App): def bili_live(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
return platform_manager["bilibili-live"] return platform_manager["bilibili-live"](AsyncClient())
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -1,6 +1,6 @@
import pytest import pytest
import respx import respx
from httpx import Response from httpx import AsyncClient, Response
from nonebug.app import App from nonebug.app import App
from .utils import get_json from .utils import get_json
@ -10,7 +10,7 @@ from .utils import get_json
def ff14(app: App): def ff14(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
return platform_manager["ff14"] return platform_manager["ff14"](AsyncClient())
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -1,6 +1,6 @@
import pytest import pytest
import respx import respx
from httpx import Response from httpx import AsyncClient, Response
from nonebug.app import App from nonebug.app import App
from .utils import get_file, get_json from .utils import get_file, get_json
@ -10,7 +10,7 @@ from .utils import get_file, get_json
def mcbbsnews(app: App): def mcbbsnews(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
return platform_manager["mcbbsnews"] return platform_manager["mcbbsnews"](AsyncClient())
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -3,20 +3,20 @@ import typing
import pytest import pytest
import respx import respx
from httpx import Response from httpx import AsyncClient, Response
from nonebug.app import App from nonebug.app import App
from .utils import get_json from .utils import get_json
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from nonebot_bison.platform.ncm_artist import NcmArtist from nonebot_bison.platform.ncm import NcmArtist
@pytest.fixture @pytest.fixture
def ncm_artist(app: App): def ncm_artist(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
return platform_manager["ncm-artist"] return platform_manager["ncm-artist"](AsyncClient())
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -3,20 +3,20 @@ import typing
import pytest import pytest
import respx import respx
from httpx import Response from httpx import AsyncClient, Response
from nonebug.app import App from nonebug.app import App
from .utils import get_json from .utils import get_json
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from nonebot_bison.platform.ncm_radio import NcmRadio from nonebot_bison.platform.ncm import NcmRadio
@pytest.fixture @pytest.fixture
def ncm_radio(app: App): def ncm_radio(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
return platform_manager["ncm-radio"] return platform_manager["ncm-radio"](AsyncClient())
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@ -53,12 +53,12 @@ def mock_platform_without_cats_tags(app: App):
categories = {} categories = {}
has_target = True has_target = True
def __init__(self): def __init__(self, client):
self.sub_index = 0 self.sub_index = 0
super().__init__() super().__init__(client)
@staticmethod @classmethod
async def get_target_name(_: "Target"): async def get_target_name(cls, client, _: "Target"):
return "MockPlatform" return "MockPlatform"
def get_id(self, post: "RawPost") -> Any: def get_id(self, post: "RawPost") -> Any:
@ -82,7 +82,7 @@ def mock_platform_without_cats_tags(app: App):
else: else:
return raw_post_list_2 return raw_post_list_2
return MockPlatform() return MockPlatform(None)
@pytest.fixture @pytest.fixture
@ -112,9 +112,9 @@ def mock_platform(app: App):
Category(2): "视频", Category(2): "视频",
} }
def __init__(self): def __init__(self, client):
self.sub_index = 0 self.sub_index = 0
super().__init__() super().__init__(client)
@staticmethod @staticmethod
async def get_target_name(_: "Target"): async def get_target_name(_: "Target"):
@ -147,7 +147,7 @@ def mock_platform(app: App):
else: else:
return raw_post_list_2 return raw_post_list_2
return MockPlatform() return MockPlatform(None)
@pytest.fixture @pytest.fixture
@ -180,9 +180,9 @@ def mock_platform_no_target(app: App, mock_scheduler_conf):
has_target = False has_target = False
categories = {Category(1): "转发", Category(2): "视频", Category(3): "不支持"} categories = {Category(1): "转发", Category(2): "视频", Category(3): "不支持"}
def __init__(self): def __init__(self, client):
self.sub_index = 0 self.sub_index = 0
super().__init__() super().__init__(client)
@staticmethod @staticmethod
async def get_target_name(_: "Target"): async def get_target_name(_: "Target"):
@ -217,7 +217,7 @@ def mock_platform_no_target(app: App, mock_scheduler_conf):
else: else:
return raw_post_list_2 return raw_post_list_2
return MockPlatform() return MockPlatform
@pytest.fixture @pytest.fixture
@ -241,12 +241,12 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf):
Category(5): "leixing5", Category(5): "leixing5",
} }
def __init__(self): def __init__(self, client):
self.sub_index = 0 self.sub_index = 0
super().__init__() super().__init__(client)
@staticmethod @classmethod
async def get_target_name(_: "Target"): async def get_target_name(cls, client, _: "Target"):
return "MockPlatform" return "MockPlatform"
def get_id(self, post: "RawPost") -> Any: def get_id(self, post: "RawPost") -> Any:
@ -284,7 +284,7 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf):
else: else:
return list_2 return list_2
return MockPlatform() return MockPlatform
@pytest.fixture @pytest.fixture
@ -308,9 +308,9 @@ def mock_status_change(app: App):
Category(2): "视频", Category(2): "视频",
} }
def __init__(self): def __init__(self, client):
self.sub_index = 0 self.sub_index = 0
super().__init__() super().__init__(client)
async def get_status(self, _: "Target"): async def get_status(self, _: "Target"):
if self.sub_index == 0: if self.sub_index == 0:
@ -335,7 +335,7 @@ def mock_status_change(app: App):
def get_category(self, raw_post): def get_category(self, raw_post):
return raw_post["cat"] return raw_post["cat"]
return MockPlatform() return MockPlatform(None)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -388,6 +388,7 @@ async def test_new_message_target(mock_platform, user_info_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_message_no_target(mock_platform_no_target, user_info_factory): async def test_new_message_no_target(mock_platform_no_target, user_info_factory):
mock_platform_no_target = mock_platform_no_target(None)
res1 = await mock_platform_no_target.fetch_new_post( res1 = await mock_platform_no_target.fetch_new_post(
"dummy", [user_info_factory([1, 2], [])] "dummy", [user_info_factory([1, 2], [])]
) )
@ -457,11 +458,14 @@ async def test_group(
user_info_factory, user_info_factory,
): ):
from nonebot_bison.platform.platform import NoTargetGroup from nonebot_bison.platform.platform import make_no_target_group
from nonebot_bison.post import Post from nonebot_bison.post import Post
from nonebot_bison.types import Category, RawPost, Tag, Target from nonebot_bison.types import Category, RawPost, Tag, Target
group_platform = NoTargetGroup([mock_platform_no_target, mock_platform_no_target_2]) group_platform_class = make_no_target_group(
[mock_platform_no_target, mock_platform_no_target_2]
)
group_platform = group_platform_class(None)
res1 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])]) res1 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])])
assert len(res1) == 0 assert len(res1) == 0
res2 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])]) res2 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])])

View File

@ -1,4 +1,5 @@
import pytest import pytest
from httpx import AsyncClient
from nonebug.app import App from nonebug.app import App
from .utils import get_json from .utils import get_json
@ -14,7 +15,7 @@ def test_cases():
async def test_filter_user_custom_tag(app: App, test_cases): async def test_filter_user_custom_tag(app: App, test_cases):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
bilibili = platform_manager["bilibili"] bilibili = platform_manager["bilibili"](AsyncClient())
for case in test_cases: for case in test_cases:
res = bilibili.is_banned_post(**case["case"]) res = bilibili.is_banned_post(**case["case"])
assert res == case["result"] assert res == case["result"]
@ -25,7 +26,7 @@ async def test_filter_user_custom_tag(app: App, test_cases):
async def test_tag_separator(app: App): async def test_tag_separator(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
bilibili = platform_manager["bilibili"] bilibili = platform_manager["bilibili"](AsyncClient())
tags = ["~111", "222", "333", "~444", "555"] tags = ["~111", "222", "333", "~444", "555"]
res = bilibili.tag_separator(tags) res = bilibili.tag_separator(tags)
assert res[0] == ["222", "333", "555"] assert res[0] == ["222", "333", "555"]

View File

@ -4,7 +4,7 @@ from datetime import datetime
import feedparser import feedparser
import pytest import pytest
import respx import respx
from httpx import Response from httpx import AsyncClient, Response
from nonebug.app import App from nonebug.app import App
from pytz import timezone from pytz import timezone
@ -18,7 +18,7 @@ if typing.TYPE_CHECKING:
def weibo(app: App): def weibo(app: App):
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
return platform_manager["weibo"] return platform_manager["weibo"](AsyncClient())
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -35,7 +35,7 @@ async def test_get_name(weibo):
profile_router.mock( profile_router.mock(
return_value=Response(200, json=get_json("weibo_ak_profile.json")) return_value=Response(200, json=get_json("weibo_ak_profile.json"))
) )
name = await weibo.get_target_name("6279793937") name = await weibo.get_target_name(AsyncClient(), "6279793937")
assert name == "明日方舟Arknights" assert name == "明日方舟Arknights"

View File

@ -10,7 +10,7 @@ from .utils import BotReply, fake_admin_user, fake_group_message_event
# 选择platform阶段中止 # 选择platform阶段中止
@pytest.mark.asyncio @pytest.mark.asyncio
@respx.mock @respx.mock
async def test_abort_add_on_platform(app: App, db_migration): async def test_abort_add_on_platform(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message from nonebot.adapters.onebot.v11.message import Message
from nonebot_bison.config_manager import add_sub_matcher, common_platform from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -57,7 +57,7 @@ async def test_abort_add_on_platform(app: App, db_migration):
# 输入id阶段中止 # 输入id阶段中止
@pytest.mark.asyncio @pytest.mark.asyncio
@respx.mock @respx.mock
async def test_abort_add_on_id(app: App, db_migration): async def test_abort_add_on_id(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message from nonebot.adapters.onebot.v11.message import Message
from nonebot_bison.config_manager import add_sub_matcher, common_platform from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -114,7 +114,7 @@ async def test_abort_add_on_id(app: App, db_migration):
# 输入订阅类别阶段中止 # 输入订阅类别阶段中止
@pytest.mark.asyncio @pytest.mark.asyncio
@respx.mock @respx.mock
async def test_abort_add_on_cats(app: App, db_migration): async def test_abort_add_on_cats(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message from nonebot.adapters.onebot.v11.message import Message
from nonebot_bison.config_manager import add_sub_matcher, common_platform from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -191,7 +191,7 @@ async def test_abort_add_on_cats(app: App, db_migration):
# 输入标签阶段中止 # 输入标签阶段中止
@pytest.mark.asyncio @pytest.mark.asyncio
@respx.mock @respx.mock
async def test_abort_add_on_tag(app: App, db_migration): async def test_abort_add_on_tag(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message from nonebot.adapters.onebot.v11.message import Message
from nonebot_bison.config_manager import add_sub_matcher, common_platform from nonebot_bison.config_manager import add_sub_matcher, common_platform

View File

@ -1,5 +1,3 @@
from email import message
import pytest import pytest
import respx import respx
from httpx import Response from httpx import Response
@ -189,7 +187,7 @@ async def test_add_with_target_no_cat(app: App, init_scheduler):
from nonebot_bison.config import config from nonebot_bison.config import config
from nonebot_bison.config_manager import add_sub_matcher, common_platform from nonebot_bison.config_manager import add_sub_matcher, common_platform
from nonebot_bison.platform import platform_manager from nonebot_bison.platform import platform_manager
from nonebot_bison.platform.ncm_artist import NcmArtist from nonebot_bison.platform.ncm import NcmArtist
ncm_router = respx.get("https://music.163.com/api/artist/albums/32540734") ncm_router = respx.get("https://music.163.com/api/artist/albums/32540734")
ncm_router.mock(return_value=Response(200, json=get_json("ncm_siren.json"))) ncm_router.mock(return_value=Response(200, json=get_json("ncm_siren.json")))

View File

@ -1,6 +1,7 @@
import typing import typing
import pytest import pytest
from httpx import AsyncClient
from nonebug.app import App from nonebug.app import App
@ -29,7 +30,7 @@ VuePress 由两部分组成:第一部分是一个极简静态网站生成器
async def test_arknights(app: App): async def test_arknights(app: App):
from nonebot_bison.platform.arknights import Arknights from nonebot_bison.platform.arknights import Arknights
ak = Arknights() ak = Arknights(AsyncClient())
res = await ak.parse( res = await ak.parse(
{"webUrl": "https://ak.hycdn.cn/announce/IOS/announcement/854_1644580545.html"} {"webUrl": "https://ak.hycdn.cn/announce/IOS/announcement/854_1644580545.html"}
) )