inject http client to scheduler

This commit is contained in:
felinae98 2022-10-16 17:19:55 +08:00
parent c8a4644e40
commit 74b5074f04
No known key found for this signature in database
GPG Key ID: 00C8B010587FF610
30 changed files with 461 additions and 377 deletions

View File

@ -1,6 +1,7 @@
import nonebot
from nonebot.adapters.onebot.v11.bot import Bot
from ..apis import check_sub_target
from ..config import (
NoSuchSubscribeException,
NoSuchTargetException,
@ -8,7 +9,7 @@ from ..config import (
config,
)
from ..config.db_config import SubscribeDupException
from ..platform import check_sub_target, platform_manager
from ..platform import platform_manager
from ..types import Target as T_Target
from ..types import WeightConfig
from .jwt import pack_jwt

View File

@ -0,0 +1,12 @@
from .platform import platform_manager
from .scheduler import scheduler_dict
from .types import Target
async def check_sub_target(platform_name: str, target: Target):
platform = platform_manager[platform_name]
scheduler_conf_class = platform.scheduler
scheduler = scheduler_dict[scheduler_conf_class]
client = await scheduler.scheduler_config_obj.get_query_name_client()
return await platform_manager[platform_name].get_target_name(client, target)

View File

@ -15,9 +15,10 @@ from nonebot.permission import SUPERUSER
from nonebot.rule import to_me
from nonebot.typing import T_State
from .apis import check_sub_target
from .config import config
from .config.db_config import SubscribeDupException
from .platform import Platform, check_sub_target, platform_manager
from .platform import Platform, platform_manager
from .plugin_config import plugin_config
from .types import Category, Target, User
from .utils import parse_text
@ -117,9 +118,7 @@ def do_add_sub(add_sub: Type[Matcher]):
) + "请输入订阅用户的id\n查询id获取方法请回复:“查询”"
else:
state["id"] = "default"
state["name"] = await platform_manager[state["platform"]].get_target_name(
Target("")
)
state["name"] = await check_sub_target(state["platform"], Target(""))
async def parse_id(event: MessageEvent, state: T_State):
if not isinstance(state["id"], Message):

View File

@ -2,29 +2,24 @@ from collections import defaultdict
from importlib import import_module
from pathlib import Path
from pkgutil import iter_modules
from typing import DefaultDict, Type
from .platform import NoTargetGroup, Platform
from .platform import Platform, make_no_target_group
_package_dir = str(Path(__file__).resolve().parent)
for (_, module_name, _) in iter_modules([_package_dir]):
import_module(f"{__name__}.{module_name}")
async def check_sub_target(target_type, target):
return await platform_manager[target_type].get_target_name(target)
_platform_list = defaultdict(list)
_platform_list: DefaultDict[str, list[Type[Platform]]] = defaultdict(list)
for _platform in Platform.registry:
if not _platform.enabled:
continue
_platform_list[_platform.platform_name].append(_platform)
platform_manager: dict[str, Platform] = dict()
platform_manager: dict[str, Type[Platform]] = dict()
for name, platform_list in _platform_list.items():
if len(platform_list) == 1:
platform_manager[name] = platform_list[0]()
platform_manager[name] = platform_list[0]
else:
platform_manager[name] = NoTargetGroup(
[_platform() for _platform in platform_list]
)
platform_manager[name] = make_no_target_group(platform_list)

View File

@ -1,12 +1,12 @@
import json
from typing import Any
from typing import Any, Optional
from bs4 import BeautifulSoup as bs
from httpx import AsyncClient
from nonebot.plugin import require
from ..post import Post
from ..types import Category, RawPost, Target
from ..utils import http_client
from ..utils.scheduler_config import SchedulerConfig
from .platform import CategoryNotSupport, NewMessage, StatusChange
@ -29,15 +29,17 @@ class Arknights(NewMessage):
scheduler = ArknightsSchedConf
has_target = False
async def get_target_name(self, _: Target) -> str:
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]:
async with http_client() as client:
raw_data = await client.get(
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/announcement.meta.json"
)
return json.loads(raw_data.text)["announceList"]
raw_data = await self.client.get(
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/announcement.meta.json"
)
return json.loads(raw_data.text)["announceList"]
def get_id(self, post: RawPost) -> Any:
return post["announceId"]
@ -51,8 +53,7 @@ class Arknights(NewMessage):
async def parse(self, raw_post: RawPost) -> Post:
announce_url = raw_post["webUrl"]
text = ""
async with http_client() as client:
raw_html = await client.get(announce_url)
raw_html = await self.client.get(announce_url)
soup = bs(raw_html.text, "html.parser")
pics = []
if soup.find("div", class_="standerd-container"):
@ -101,17 +102,19 @@ class AkVersion(StatusChange):
scheduler = ArknightsSchedConf
has_target = False
async def get_target_name(self, _: Target) -> str:
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "明日方舟游戏信息"
async def get_status(self, _):
async with http_client() as client:
res_ver = await client.get(
"https://ak-conf.hypergryph.com/config/prod/official/IOS/version"
)
res_preanounce = await client.get(
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json"
)
res_ver = await self.client.get(
"https://ak-conf.hypergryph.com/config/prod/official/IOS/version"
)
res_preanounce = await self.client.get(
"https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json"
)
res = res_ver.json()
res.update(res_preanounce.json())
return res
@ -156,13 +159,17 @@ class MonsterSiren(NewMessage):
scheduler = ArknightsSchedConf
has_target = False
async def get_target_name(self, _: Target) -> str:
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]:
async with http_client() as client:
raw_data = await client.get("https://monster-siren.hypergryph.com/api/news")
return raw_data.json()["data"]["list"]
raw_data = await self.client.get(
"https://monster-siren.hypergryph.com/api/news"
)
return raw_data.json()["data"]["list"]
def get_id(self, post: RawPost) -> Any:
return post["cid"]
@ -175,16 +182,15 @@ class MonsterSiren(NewMessage):
async def parse(self, raw_post: RawPost) -> Post:
url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}'
async with http_client() as client:
res = await client.get(
f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}'
)
raw_data = res.json()
content = raw_data["data"]["content"]
content = content.replace("</p>", "</p>\n")
soup = bs(content, "html.parser")
imgs = list(map(lambda x: x["src"], soup("img")))
text = f'{raw_post["title"]}\n{soup.text.strip()}'
res = await self.client.get(
f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}'
)
raw_data = res.json()
content = raw_data["data"]["content"]
content = content.replace("</p>", "</p>\n")
soup = bs(content, "html.parser")
imgs = list(map(lambda x: x["src"], soup("img")))
text = f'{raw_post["title"]}\n{soup.text.strip()}'
return Post(
"monster-siren",
text=text,
@ -207,15 +213,17 @@ class TerraHistoricusComic(NewMessage):
scheduler = ArknightsSchedConf
has_target = False
async def get_target_name(self, _: Target) -> str:
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "明日方舟游戏信息"
async def get_sub_list(self, _) -> list[RawPost]:
async with http_client() as client:
raw_data = await client.get(
"https://terra-historicus.hypergryph.com/api/recentUpdate"
)
return raw_data.json()["data"]
raw_data = await self.client.get(
"https://terra-historicus.hypergryph.com/api/recentUpdate"
)
return raw_data.json()["data"]
def get_id(self, post: RawPost) -> Any:
return f'{post["comicCid"]}/{post["episodeCid"]}'

View File

@ -5,6 +5,7 @@ from datetime import datetime, timedelta
from typing import Any, Callable, Optional
import httpx
from httpx import AsyncClient
from nonebot.log import logger
from ..post import Post
@ -20,35 +21,36 @@ class BilibiliSchedConf(SchedulerConfig):
schedule_type = "interval"
schedule_setting = {"seconds": 10}
from .platform import CategoryNotSupport, NewMessage, StatusChange
class _BilibiliClient:
_http_client: httpx.AsyncClient
_client_refresh_time: Optional[datetime]
_client_refresh_time: datetime
cookie_expire_time = timedelta(hours=5)
def __init__(self):
self._client_refresh_time = datetime(
year=2000, month=1, day=1
) # an expired time
super().__init__()
async def _init_session(self):
self._http_client = httpx.AsyncClient(**http_args)
res = await self._http_client.get("https://www.bilibili.com/")
res = await self.default_http_client.get("https://www.bilibili.com/")
if res.status_code != 200:
logger.warning("unable to refresh temp cookie")
else:
self._client_refresh_time = datetime.now()
async def _refresh_client(self):
if (
getattr(self, "_client_refresh_time", None) is None
or datetime.now() - self._client_refresh_time
> self.cookie_expire_time # type:ignore
or self._http_client is None
):
if datetime.now() - self._client_refresh_time > self.cookie_expire_time:
await self._init_session()
async def get_client(self, target: Target) -> AsyncClient:
await self._refresh_client()
return await super().get_client(target)
class Bilibili(_BilibiliClient, NewMessage):
async def get_query_name_client(self) -> AsyncClient:
await self._refresh_client()
return await super().get_query_name_client()
class Bilibili(NewMessage):
categories = {
1: "一般动态",
@ -67,17 +69,11 @@ class Bilibili(_BilibiliClient, NewMessage):
has_target = True
parse_target_promot = "请输入用户主页的链接"
def ensure_client(fun: Callable): # type:ignore
@functools.wraps(fun)
async def wrapped(self, *args, **kwargs):
await self._refresh_client()
return await fun(self, *args, **kwargs)
return wrapped
@ensure_client
async def get_target_name(self, target: Target) -> Optional[str]:
res = await self._http_client.get(
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
res = await client.get(
"https://api.bilibili.com/x/space/acc/info", params={"mid": target}
)
res_data = json.loads(res.text)
@ -85,18 +81,18 @@ class Bilibili(_BilibiliClient, NewMessage):
return None
return res_data["data"]["name"]
async def parse_target(self, target_text: str) -> Target:
@classmethod
async def parse_target(cls, target_text: str) -> Target:
if re.match(r"\d+", target_text):
return Target(target_text)
elif m := re.match(r"(?:https?://)?space\.bilibili\.com/(\d+)", target_text):
return Target(m.group(1))
else:
raise self.ParseTargetException()
raise cls.ParseTargetException()
@ensure_client
async def get_sub_list(self, target: Target) -> list[RawPost]:
params = {"host_uid": target, "offset": 0, "need_top": 0}
res = await self._http_client.get(
res = await self.client.get(
"https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history",
params=params,
timeout=4.0,
@ -202,7 +198,7 @@ class Bilibili(_BilibiliClient, NewMessage):
return Post("bilibili", text=text, url=url, pics=pic, target_name=target_name)
class Bilibililive(_BilibiliClient, StatusChange):
class Bilibililive(StatusChange):
# Author : Sichongzou
# Date : 2022-5-18 8:54
# Description : bilibili开播提醒
@ -216,17 +212,11 @@ class Bilibililive(_BilibiliClient, StatusChange):
name = "Bilibili直播"
has_target = True
def ensure_client(fun: Callable): # type:ignore
@functools.wraps(fun)
async def wrapped(self, *args, **kwargs):
await self._refresh_client()
return await fun(self, *args, **kwargs)
return wrapped
@ensure_client
async def get_target_name(self, target: Target) -> Optional[str]:
res = await self._http_client.get(
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
res = await client.get(
"https://api.bilibili.com/x/space/acc/info", params={"mid": target}
)
res_data = json.loads(res.text)
@ -234,10 +224,9 @@ class Bilibililive(_BilibiliClient, StatusChange):
return None
return res_data["data"]["name"]
@ensure_client
async def get_status(self, target: Target):
params = {"mid": target}
res = await self._http_client.get(
res = await self.client.get(
"https://api.bilibili.com/x/space/acc/info",
params=params,
timeout=4.0,
@ -279,7 +268,7 @@ class Bilibililive(_BilibiliClient, StatusChange):
)
class BilibiliBangumi(_BilibiliClient, StatusChange):
class BilibiliBangumi(StatusChange):
categories = {}
platform_name = "bilibili-bangumi"
@ -293,23 +282,18 @@ class BilibiliBangumi(_BilibiliClient, StatusChange):
_url = "https://api.bilibili.com/pgc/review/user"
def ensure_client(fun: Callable): # type:ignore
@functools.wraps(fun)
async def wrapped(self, *args, **kwargs):
await self._refresh_client()
return await fun(self, *args, **kwargs)
return wrapped
@ensure_client
async def get_target_name(self, target: Target) -> Optional[str]:
res = await self._http_client.get(self._url, params={"media_id": target})
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
res = await client.get(cls._url, params={"media_id": target})
res_data = res.json()
if res_data["code"]:
return None
return res_data["result"]["media"]["title"]
async def parse_target(self, target_string: str) -> Target:
@classmethod
async def parse_target(cls, target_string: str) -> Target:
if re.match(r"\d+", target_string):
return Target(target_string)
elif m := re.match(r"md(\d+)", target_string):
@ -318,11 +302,10 @@ class BilibiliBangumi(_BilibiliClient, StatusChange):
r"(?:https?://)?www\.bilibili\.com/bangumi/media/md(\d+)/", target_string
):
return Target(m.group(1))
raise self.ParseTargetException()
raise cls.ParseTargetException()
@ensure_client
async def get_status(self, target: Target):
res = await self._http_client.get(
res = await self.client.get(
self._url,
params={"media_id": target},
timeout=4.0,
@ -343,9 +326,8 @@ class BilibiliBangumi(_BilibiliClient, StatusChange):
else:
return []
@ensure_client
async def parse(self, raw_post: RawPost) -> Post:
detail_res = await self._http_client.get(
detail_res = await self.client.get(
f'http://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}'
)
detail_dict = detail_res.json()

View File

@ -1,4 +1,6 @@
from typing import Any
from typing import Any, Optional
from httpx import AsyncClient
from ..post import Post
from ..types import RawPost, Target
@ -18,7 +20,10 @@ class FF14(NewMessage):
scheduler = scheduler("interval", {"seconds": 60})
has_target = False
async def get_target_name(self, _: Target) -> str:
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return "最终幻想XIV官方公告"
async def get_sub_list(self, _) -> list[RawPost]:

View File

@ -1,9 +1,10 @@
import re
import time
from typing import Literal
from typing import Literal, Optional
import httpx
from bs4 import BeautifulSoup, NavigableString, Tag
from httpx import AsyncClient
from ..post import Post
from ..types import Category, RawPost, Target
@ -42,8 +43,11 @@ class McbbsNews(NewMessage):
scheduler = scheduler("interval", {"hours": 1})
has_target = False
async def get_target_name(self, _: Target) -> str:
return self.name
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
return cls.name
async def get_sub_list(self, _: Target) -> list[RawPost]:
url = "https://www.mcbbs.net/forum-news-1.html"

View File

@ -0,0 +1,146 @@
import re
from typing import Any, Optional
from httpx import AsyncClient
from ..post import Post
from ..types import RawPost, Target
from ..utils import SchedulerConfig, http_client
from .platform import NewMessage
class NcmSchedConf(SchedulerConfig):
name = "music.163.com"
schedule_type = "interval"
schedule_setting = {"minutes": 1}
class NcmArtist(NewMessage):
categories = {}
platform_name = "ncm-artist"
enable_tag = False
enabled = True
is_common = True
scheduler = NcmSchedConf
name = "网易云-歌手"
has_target = True
parse_target_promot = "请输入歌手主页包含数字ID的链接"
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async with http_client() as client:
res = await client.get(
"https://music.163.com/api/artist/albums/{}".format(target),
headers={"Referer": "https://music.163.com/"},
)
res_data = res.json()
if res_data["code"] != 200:
return
return res_data["artist"]["name"]
@classmethod
async def parse_target(cls, target_text: str) -> Target:
if re.match(r"^\d+$", target_text):
return Target(target_text)
elif match := re.match(
r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text
):
return Target(match.group(1))
else:
raise cls.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]:
async with http_client() as client:
res = await client.get(
"https://music.163.com/api/artist/albums/{}".format(target),
headers={"Referer": "https://music.163.com/"},
)
res_data = res.json()
if res_data["code"] != 200:
return []
else:
return res_data["hotAlbums"]
def get_id(self, post: RawPost) -> Any:
return post["id"]
def get_date(self, post: RawPost) -> int:
return post["publishTime"] // 1000
async def parse(self, raw_post: RawPost) -> Post:
text = "新专辑发布:{}".format(raw_post["name"])
target_name = raw_post["artist"]["name"]
pics = [raw_post["picUrl"]]
url = "https://music.163.com/#/album?id={}".format(raw_post["id"])
return Post(
"ncm-artist", text=text, url=url, pics=pics, target_name=target_name
)
class NcmRadio(NewMessage):
categories = {}
platform_name = "ncm-radio"
enable_tag = False
enabled = True
is_common = False
scheduler = NcmSchedConf
name = "网易云-电台"
has_target = True
parse_target_promot = "请输入主播电台主页包含数字ID的链接"
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async with http_client() as client:
res = await client.post(
"http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"},
data={"radioId": target, "limit": 1000, "offset": 0},
)
res_data = res.json()
if res_data["code"] != 200 or res_data["programs"] == 0:
return
return res_data["programs"][0]["radio"]["name"]
@classmethod
async def parse_target(cls, target_text: str) -> Target:
if re.match(r"^\d+$", target_text):
return Target(target_text)
elif match := re.match(
r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text
):
return Target(match.group(1))
else:
raise cls.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]:
async with http_client() as client:
res = await client.post(
"http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"},
data={"radioId": target, "limit": 1000, "offset": 0},
)
res_data = res.json()
if res_data["code"] != 200:
return []
else:
return res_data["programs"]
def get_id(self, post: RawPost) -> Any:
return post["id"]
def get_date(self, post: RawPost) -> int:
return post["createTime"] // 1000
async def parse(self, raw_post: RawPost) -> Post:
text = "网易云电台更新:{}".format(raw_post["name"])
target_name = raw_post["radio"]["name"]
pics = [raw_post["coverUrl"]]
url = "https://music.163.com/#/program/{}".format(raw_post["id"])
return Post("ncm-radio", text=text, url=url, pics=pics, target_name=target_name)

View File

@ -1,75 +0,0 @@
import re
from typing import Any, Optional
from ..post import Post
from ..types import RawPost, Target
from ..utils import SchedulerConfig, http_client
from .platform import NewMessage
class NcmSchedConf(SchedulerConfig):
name = "music.163.com"
schedule_type = "interval"
schedule_setting = {"minutes": 1}
class NcmArtist(NewMessage):
categories = {}
platform_name = "ncm-artist"
enable_tag = False
enabled = True
is_common = True
scheduler = NcmSchedConf
name = "网易云-歌手"
has_target = True
parse_target_promot = "请输入歌手主页包含数字ID的链接"
async def get_target_name(self, target: Target) -> Optional[str]:
async with http_client() as client:
res = await client.get(
"https://music.163.com/api/artist/albums/{}".format(target),
headers={"Referer": "https://music.163.com/"},
)
res_data = res.json()
if res_data["code"] != 200:
return
return res_data["artist"]["name"]
async def parse_target(self, target_text: str) -> Target:
if re.match(r"^\d+$", target_text):
return Target(target_text)
elif match := re.match(
r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text
):
return Target(match.group(1))
else:
raise self.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]:
async with http_client() as client:
res = await client.get(
"https://music.163.com/api/artist/albums/{}".format(target),
headers={"Referer": "https://music.163.com/"},
)
res_data = res.json()
if res_data["code"] != 200:
return []
else:
return res_data["hotAlbums"]
def get_id(self, post: RawPost) -> Any:
return post["id"]
def get_date(self, post: RawPost) -> int:
return post["publishTime"] // 1000
async def parse(self, raw_post: RawPost) -> Post:
text = "新专辑发布:{}".format(raw_post["name"])
target_name = raw_post["artist"]["name"]
pics = [raw_post["picUrl"]]
url = "https://music.163.com/#/album?id={}".format(raw_post["id"])
return Post(
"ncm-artist", text=text, url=url, pics=pics, target_name=target_name
)

View File

@ -1,69 +0,0 @@
import re
from typing import Any, Optional
from ..post import Post
from ..types import RawPost, Target
from ..utils import http_client
from .ncm_artist import NcmSchedConf
from .platform import NewMessage
class NcmRadio(NewMessage):
categories = {}
platform_name = "ncm-radio"
enable_tag = False
enabled = True
is_common = False
scheduler = NcmSchedConf
name = "网易云-电台"
has_target = True
parse_target_promot = "请输入主播电台主页包含数字ID的链接"
async def get_target_name(self, target: Target) -> Optional[str]:
async with http_client() as client:
res = await client.post(
"http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"},
data={"radioId": target, "limit": 1000, "offset": 0},
)
res_data = res.json()
if res_data["code"] != 200 or res_data["programs"] == 0:
return
return res_data["programs"][0]["radio"]["name"]
async def parse_target(self, target_text: str) -> Target:
if re.match(r"^\d+$", target_text):
return Target(target_text)
elif match := re.match(
r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text
):
return Target(match.group(1))
else:
raise self.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]:
async with http_client() as client:
res = await client.post(
"http://music.163.com/api/dj/program/byradio",
headers={"Referer": "https://music.163.com/"},
data={"radioId": target, "limit": 1000, "offset": 0},
)
res_data = res.json()
if res_data["code"] != 200:
return []
else:
return res_data["programs"]
def get_id(self, post: RawPost) -> Any:
return post["id"]
def get_date(self, post: RawPost) -> int:
return post["createTime"] // 1000
async def parse(self, raw_post: RawPost) -> Post:
text = "网易云电台更新:{}".format(raw_post["name"])
target_name = raw_post["radio"]["name"]
pics = [raw_post["coverUrl"]]
url = "https://music.163.com/#/program/{}".format(raw_post["id"])
return Post("ncm-radio", text=text, url=url, pics=pics, target_name=target_name)

View File

@ -1,12 +1,14 @@
import json
import ssl
import time
import typing
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Collection, Literal, Optional, Type
from typing import Any, Collection, Optional, Type
import httpx
from httpx import AsyncClient
from nonebot.log import logger
from ..plugin_config import plugin_config
@ -34,11 +36,23 @@ class RegistryMeta(type):
super().__init__(name, bases, namespace, **kwargs)
class RegistryABCMeta(RegistryMeta, ABC):
class PlatformMeta(RegistryMeta):
categories: dict[Category, str]
def __init__(cls, name, bases, namespace, **kwargs):
cls.reverse_category = {}
if hasattr(cls, "categories") and cls.categories:
for key, val in cls.categories.items():
cls.reverse_category[val] = key
super().__init__(name, bases, namespace, **kwargs)
class PlatformABCMeta(PlatformMeta, ABC):
...
class Platform(metaclass=RegistryABCMeta, base=True):
class Platform(metaclass=PlatformABCMeta, base=True):
scheduler: Type[SchedulerConfig]
is_common: bool
@ -50,9 +64,15 @@ class Platform(metaclass=RegistryABCMeta, base=True):
store: dict[Target, Any]
platform_name: str
parse_target_promot: Optional[str] = None
registry: list[Type["Platform"]]
client: AsyncClient
reverse_category: dict[str, Category]
@classmethod
@abstractmethod
async def get_target_name(self, target: Target) -> Optional[str]:
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
...
@abstractmethod
@ -88,17 +108,16 @@ class Platform(metaclass=RegistryABCMeta, base=True):
"actually function called"
return await self.parse(raw_post)
def __init__(self):
def __init__(self, client: AsyncClient):
super().__init__()
self.reverse_category = {}
for key, val in self.categories.items():
self.reverse_category[val] = key
self.store = dict()
self.client = client
class ParseTargetException(Exception):
pass
async def parse_target(self, target_string: str) -> Target:
@classmethod
async def parse_target(cls, target_string: str) -> Target:
return Target(target_string)
@abstractmethod
@ -188,8 +207,8 @@ class Platform(metaclass=RegistryABCMeta, base=True):
class MessageProcess(Platform, abstract=True):
"General message process fetch, parse, filter progress"
def __init__(self):
super().__init__()
def __init__(self, client: AsyncClient):
super().__init__(client)
self.parse_cache: dict[Any, Post] = dict()
@abstractmethod
@ -362,55 +381,82 @@ class SimplePost(MessageProcess, abstract=True):
return res
class NoTargetGroup(Platform, abstract=True):
enable_tag = False
def make_no_target_group(platform_list: list[Type[Platform]]) -> Type[Platform]:
if typing.TYPE_CHECKING:
class NoTargetGroup(Platform, abstract=True):
platform_list: list[Type[Platform]]
platform_obj_list: list[Platform]
DUMMY_STR = "_DUMMY"
enabled = True
has_target = False
def __init__(self, platform_list: list[Platform]):
self.platform_list = platform_list
self.platform_name = platform_list[0].platform_name
name = self.DUMMY_STR
self.categories = {}
categories_keys = set()
self.scheduler = platform_list[0].scheduler
for platform in platform_list:
if platform.has_target:
raise RuntimeError(
"Platform {} should have no target".format(platform.name)
)
if name == self.DUMMY_STR:
name = platform.name
elif name != platform.name:
raise RuntimeError(
"Platform name for {} not fit".format(self.platform_name)
)
platform_category_key_set = set(platform.categories.keys())
if platform_category_key_set & categories_keys:
raise RuntimeError(
"Platform categories for {} duplicate".format(self.platform_name)
)
categories_keys |= platform_category_key_set
self.categories.update(platform.categories)
if platform.scheduler != self.scheduler:
raise RuntimeError(
"Platform scheduler for {} not fit".format(self.platform_name)
)
self.name = name
self.is_common = platform_list[0].is_common
super().__init__()
platform_name = platform_list[0].platform_name
name = DUMMY_STR
categories_keys = set()
categories = {}
scheduler = platform_list[0].scheduler
def __str__(self):
for platform in platform_list:
if platform.has_target:
raise RuntimeError(
"Platform {} should have no target".format(platform.name)
)
if name == DUMMY_STR:
name = platform.name
elif name != platform.name:
raise RuntimeError("Platform name for {} not fit".format(platform_name))
platform_category_key_set = set(platform.categories.keys())
if platform_category_key_set & categories_keys:
raise RuntimeError(
"Platform categories for {} duplicate".format(platform_name)
)
categories_keys |= platform_category_key_set
categories.update(platform.categories)
if platform.scheduler != scheduler:
raise RuntimeError(
"Platform scheduler for {} not fit".format(platform_name)
)
def __init__(self: "NoTargetGroup", client: AsyncClient):
Platform.__init__(self, client)
self.platform_obj_list = []
for platform_class in self.platform_list:
self.platform_obj_list.append(platform_class(client))
def __str__(self: "NoTargetGroup") -> str:
return "[" + " ".join(map(lambda x: x.name, self.platform_list)) + "]"
async def get_target_name(self, _):
return await self.platform_list[0].get_target_name(_)
@classmethod
async def get_target_name(cls, client: AsyncClient, target: Target):
return await platform_list[0].get_target_name(client, target)
async def fetch_new_post(self, target, users):
async def fetch_new_post(
self: "NoTargetGroup", target: Target, users: list[UserSubInfo]
):
res = defaultdict(list)
for platform in self.platform_list:
for platform in self.platform_obj_list:
platform_res = await platform.fetch_new_post(target=target, users=users)
for user, posts in platform_res:
res[user].extend(posts)
return [[key, val] for key, val in res.items()]
return type(
"NoTargetGroup",
(Platform,),
{
"platform_list": platform_list,
"platform_name": platform_list[0].platform_name,
"name": name,
"categories": categories,
"scheduler": scheduler,
"is_common": platform_list[0].is_common,
"enabled": True,
"has_target": False,
"enable_tag": False,
"__init__": __init__,
"get_target_name": get_target_name,
"fetch_new_post": fetch_new_post,
},
abstract=True,
)

View File

@ -3,6 +3,7 @@ from typing import Any, Optional
import feedparser
from bs4 import BeautifulSoup as bs
from httpx import AsyncClient
from ..post import Post
from ..types import RawPost, Target
@ -21,7 +22,10 @@ class Rss(NewMessage):
scheduler = scheduler("interval", {"seconds": 30})
has_target = True
async def get_target_name(self, target: Target) -> Optional[str]:
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async with http_client() as client:
res = await client.get(target, timeout=10.0)
feed = feedparser.parse(res.text)

View File

@ -5,6 +5,7 @@ from datetime import datetime
from typing import Any, Optional
from bs4 import BeautifulSoup as bs
from httpx import AsyncClient
from nonebot.log import logger
from ..post import Post
@ -36,7 +37,10 @@ class Weibo(NewMessage):
has_target = True
parse_target_promot = "请输入用户主页包含数字UID的链接"
async def get_target_name(self, target: Target) -> Optional[str]:
@classmethod
async def get_target_name(
cls, client: AsyncClient, target: Target
) -> Optional[str]:
async with http_client() as client:
param = {"containerid": "100505" + target}
res = await client.get(
@ -48,14 +52,15 @@ class Weibo(NewMessage):
else:
return None
async def parse_target(self, target_text: str) -> Target:
@classmethod
async def parse_target(cls, target_text: str) -> Target:
if re.match(r"\d+", target_text):
return Target(target_text)
elif match := re.match(r"(?:https?://)?weibo\.com/u/(\d+)", target_text):
# 都2202年了应该不会有http了吧不过还是防一手
return Target(match.group(1))
else:
raise self.ParseTargetException()
raise cls.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]:
async with http_client() as client:

View File

@ -7,7 +7,6 @@ from nonebot.log import logger
from ..config import config
from ..platform import platform_manager
from ..platform.platform import Platform
from ..send import send_msgs
from ..types import Target
from ..utils import SchedulerConfig
@ -36,6 +35,7 @@ class Scheduler:
logger.error(f"scheduler config [{self.name}] not found, exiting")
raise RuntimeError(f"{self.name} not found")
self.scheduler_config = scheduler_config
self.scheduler_config_obj = self.scheduler_config()
self.schedulable_list = []
for platform_name, target in schedulables:
self.schedulable_list.append(
@ -86,7 +86,10 @@ class Scheduler:
send_userinfo_list = await config.get_platform_target_subscribers(
schedulable.platform_name, schedulable.target
)
to_send = await platform_manager[schedulable.platform_name].do_fetch_new_post(
platform_obj = platform_manager[schedulable.platform_name](
await self.scheduler_config_obj.get_client(schedulable.target)
)
to_send = await platform_obj.do_fetch_new_post(
schedulable.target, send_userinfo_list
)
if not to_send:

View File

@ -1,5 +1,10 @@
from typing import Literal, Type
from httpx import AsyncClient
from ..types import Target
from .http import http_client
class SchedulerConfig:
@ -10,6 +15,15 @@ class SchedulerConfig:
def __str__(self):
return f"[{self.name}]-{self.name}-{self.schedule_setting}"
def __init__(self):
self.default_http_client = http_client()
async def get_client(self, target: Target) -> AsyncClient:
return self.default_http_client
async def get_query_name_client(self) -> AsyncClient:
return self.default_http_client
def scheduler(
schedule_type: Literal["date", "interval", "cron"], schedule_setting: dict

View File

@ -1,6 +1,6 @@
import pytest
import respx
from httpx import Response
from httpx import AsyncClient, Response
from nonebug.app import App
from .utils import get_file, get_json
@ -10,7 +10,7 @@ from .utils import get_file, get_json
def arknights(app: App):
from nonebot_bison.platform import platform_manager
return platform_manager["arknights"]
return platform_manager["arknights"](AsyncClient())
@pytest.fixture(scope="module")

View File

@ -3,7 +3,7 @@ from datetime import datetime
import pytest
import respx
from httpx import Response
from httpx import AsyncClient, Response
from nonebug.app import App
from pytz import timezone
@ -23,7 +23,7 @@ if typing.TYPE_CHECKING:
def bilibili(app: App):
from nonebot_bison.platform import platform_manager
return platform_manager["bilibili"]
return platform_manager["bilibili"](AsyncClient())
@pytest.mark.asyncio

View File

@ -2,7 +2,7 @@ import typing
import pytest
import respx
from httpx import Response
from httpx import AsyncClient, Response
from nonebug.app import App
from .utils import get_json
@ -15,7 +15,7 @@ if typing.TYPE_CHECKING:
def bili_bangumi(app: App):
from nonebot_bison.platform import platform_manager
return platform_manager["bilibili-bangumi"]
return platform_manager["bilibili-bangumi"](AsyncClient())
@pytest.mark.asyncio

View File

@ -1,6 +1,6 @@
import pytest
import respx
from httpx import Response
from httpx import AsyncClient, Response
from nonebug.app import App
from .utils import get_json
@ -10,7 +10,7 @@ from .utils import get_json
def bili_live(app: App):
from nonebot_bison.platform import platform_manager
return platform_manager["bilibili-live"]
return platform_manager["bilibili-live"](AsyncClient())
@pytest.mark.asyncio

View File

@ -1,6 +1,6 @@
import pytest
import respx
from httpx import Response
from httpx import AsyncClient, Response
from nonebug.app import App
from .utils import get_json
@ -10,7 +10,7 @@ from .utils import get_json
def ff14(app: App):
from nonebot_bison.platform import platform_manager
return platform_manager["ff14"]
return platform_manager["ff14"](AsyncClient())
@pytest.fixture(scope="module")

View File

@ -1,6 +1,6 @@
import pytest
import respx
from httpx import Response
from httpx import AsyncClient, Response
from nonebug.app import App
from .utils import get_file, get_json
@ -10,7 +10,7 @@ from .utils import get_file, get_json
def mcbbsnews(app: App):
from nonebot_bison.platform import platform_manager
return platform_manager["mcbbsnews"]
return platform_manager["mcbbsnews"](AsyncClient())
@pytest.fixture(scope="module")

View File

@ -3,20 +3,20 @@ import typing
import pytest
import respx
from httpx import Response
from httpx import AsyncClient, Response
from nonebug.app import App
from .utils import get_json
if typing.TYPE_CHECKING:
from nonebot_bison.platform.ncm_artist import NcmArtist
from nonebot_bison.platform.ncm import NcmArtist
@pytest.fixture
def ncm_artist(app: App):
from nonebot_bison.platform import platform_manager
return platform_manager["ncm-artist"]
return platform_manager["ncm-artist"](AsyncClient())
@pytest.fixture(scope="module")

View File

@ -3,20 +3,20 @@ import typing
import pytest
import respx
from httpx import Response
from httpx import AsyncClient, Response
from nonebug.app import App
from .utils import get_json
if typing.TYPE_CHECKING:
from nonebot_bison.platform.ncm_radio import NcmRadio
from nonebot_bison.platform.ncm import NcmRadio
@pytest.fixture
def ncm_radio(app: App):
from nonebot_bison.platform import platform_manager
return platform_manager["ncm-radio"]
return platform_manager["ncm-radio"](AsyncClient())
@pytest.fixture(scope="module")

View File

@ -53,12 +53,12 @@ def mock_platform_without_cats_tags(app: App):
categories = {}
has_target = True
def __init__(self):
def __init__(self, client):
self.sub_index = 0
super().__init__()
super().__init__(client)
@staticmethod
async def get_target_name(_: "Target"):
@classmethod
async def get_target_name(cls, client, _: "Target"):
return "MockPlatform"
def get_id(self, post: "RawPost") -> Any:
@ -82,7 +82,7 @@ def mock_platform_without_cats_tags(app: App):
else:
return raw_post_list_2
return MockPlatform()
return MockPlatform(None)
@pytest.fixture
@ -112,9 +112,9 @@ def mock_platform(app: App):
Category(2): "视频",
}
def __init__(self):
def __init__(self, client):
self.sub_index = 0
super().__init__()
super().__init__(client)
@staticmethod
async def get_target_name(_: "Target"):
@ -147,7 +147,7 @@ def mock_platform(app: App):
else:
return raw_post_list_2
return MockPlatform()
return MockPlatform(None)
@pytest.fixture
@ -180,9 +180,9 @@ def mock_platform_no_target(app: App, mock_scheduler_conf):
has_target = False
categories = {Category(1): "转发", Category(2): "视频", Category(3): "不支持"}
def __init__(self):
def __init__(self, client):
self.sub_index = 0
super().__init__()
super().__init__(client)
@staticmethod
async def get_target_name(_: "Target"):
@ -217,7 +217,7 @@ def mock_platform_no_target(app: App, mock_scheduler_conf):
else:
return raw_post_list_2
return MockPlatform()
return MockPlatform
@pytest.fixture
@ -241,12 +241,12 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf):
Category(5): "leixing5",
}
def __init__(self):
def __init__(self, client):
self.sub_index = 0
super().__init__()
super().__init__(client)
@staticmethod
async def get_target_name(_: "Target"):
@classmethod
async def get_target_name(cls, client, _: "Target"):
return "MockPlatform"
def get_id(self, post: "RawPost") -> Any:
@ -284,7 +284,7 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf):
else:
return list_2
return MockPlatform()
return MockPlatform
@pytest.fixture
@ -308,9 +308,9 @@ def mock_status_change(app: App):
Category(2): "视频",
}
def __init__(self):
def __init__(self, client):
self.sub_index = 0
super().__init__()
super().__init__(client)
async def get_status(self, _: "Target"):
if self.sub_index == 0:
@ -335,7 +335,7 @@ def mock_status_change(app: App):
def get_category(self, raw_post):
return raw_post["cat"]
return MockPlatform()
return MockPlatform(None)
@pytest.mark.asyncio
@ -388,6 +388,7 @@ async def test_new_message_target(mock_platform, user_info_factory):
@pytest.mark.asyncio
async def test_new_message_no_target(mock_platform_no_target, user_info_factory):
mock_platform_no_target = mock_platform_no_target(None)
res1 = await mock_platform_no_target.fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
)
@ -457,11 +458,14 @@ async def test_group(
user_info_factory,
):
from nonebot_bison.platform.platform import NoTargetGroup
from nonebot_bison.platform.platform import make_no_target_group
from nonebot_bison.post import Post
from nonebot_bison.types import Category, RawPost, Tag, Target
group_platform = NoTargetGroup([mock_platform_no_target, mock_platform_no_target_2])
group_platform_class = make_no_target_group(
[mock_platform_no_target, mock_platform_no_target_2]
)
group_platform = group_platform_class(None)
res1 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])])
assert len(res1) == 0
res2 = await group_platform.fetch_new_post("dummy", [user_info_factory([1, 4], [])])

View File

@ -1,4 +1,5 @@
import pytest
from httpx import AsyncClient
from nonebug.app import App
from .utils import get_json
@ -14,7 +15,7 @@ def test_cases():
async def test_filter_user_custom_tag(app: App, test_cases):
from nonebot_bison.platform import platform_manager
bilibili = platform_manager["bilibili"]
bilibili = platform_manager["bilibili"](AsyncClient())
for case in test_cases:
res = bilibili.is_banned_post(**case["case"])
assert res == case["result"]
@ -25,7 +26,7 @@ async def test_filter_user_custom_tag(app: App, test_cases):
async def test_tag_separator(app: App):
from nonebot_bison.platform import platform_manager
bilibili = platform_manager["bilibili"]
bilibili = platform_manager["bilibili"](AsyncClient())
tags = ["~111", "222", "333", "~444", "555"]
res = bilibili.tag_separator(tags)
assert res[0] == ["222", "333", "555"]

View File

@ -4,7 +4,7 @@ from datetime import datetime
import feedparser
import pytest
import respx
from httpx import Response
from httpx import AsyncClient, Response
from nonebug.app import App
from pytz import timezone
@ -18,7 +18,7 @@ if typing.TYPE_CHECKING:
def weibo(app: App):
from nonebot_bison.platform import platform_manager
return platform_manager["weibo"]
return platform_manager["weibo"](AsyncClient())
@pytest.fixture(scope="module")
@ -35,7 +35,7 @@ async def test_get_name(weibo):
profile_router.mock(
return_value=Response(200, json=get_json("weibo_ak_profile.json"))
)
name = await weibo.get_target_name("6279793937")
name = await weibo.get_target_name(AsyncClient(), "6279793937")
assert name == "明日方舟Arknights"

View File

@ -10,7 +10,7 @@ from .utils import BotReply, fake_admin_user, fake_group_message_event
# 选择platform阶段中止
@pytest.mark.asyncio
@respx.mock
async def test_abort_add_on_platform(app: App, db_migration):
async def test_abort_add_on_platform(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message
from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -57,7 +57,7 @@ async def test_abort_add_on_platform(app: App, db_migration):
# 输入id阶段中止
@pytest.mark.asyncio
@respx.mock
async def test_abort_add_on_id(app: App, db_migration):
async def test_abort_add_on_id(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message
from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -114,7 +114,7 @@ async def test_abort_add_on_id(app: App, db_migration):
# 输入订阅类别阶段中止
@pytest.mark.asyncio
@respx.mock
async def test_abort_add_on_cats(app: App, db_migration):
async def test_abort_add_on_cats(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message
from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -191,7 +191,7 @@ async def test_abort_add_on_cats(app: App, db_migration):
# 输入标签阶段中止
@pytest.mark.asyncio
@respx.mock
async def test_abort_add_on_tag(app: App, db_migration):
async def test_abort_add_on_tag(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message
from nonebot_bison.config_manager import add_sub_matcher, common_platform

View File

@ -1,5 +1,3 @@
from email import message
import pytest
import respx
from httpx import Response
@ -189,7 +187,7 @@ async def test_add_with_target_no_cat(app: App, init_scheduler):
from nonebot_bison.config import config
from nonebot_bison.config_manager import add_sub_matcher, common_platform
from nonebot_bison.platform import platform_manager
from nonebot_bison.platform.ncm_artist import NcmArtist
from nonebot_bison.platform.ncm import NcmArtist
ncm_router = respx.get("https://music.163.com/api/artist/albums/32540734")
ncm_router.mock(return_value=Response(200, json=get_json("ncm_siren.json")))

View File

@ -1,6 +1,7 @@
import typing
import pytest
from httpx import AsyncClient
from nonebug.app import App
@ -29,7 +30,7 @@ VuePress 由两部分组成:第一部分是一个极简静态网站生成器
async def test_arknights(app: App):
from nonebot_bison.platform.arknights import Arknights
ak = Arknights()
ak = Arknights(AsyncClient())
res = await ak.parse(
{"webUrl": "https://ak.hycdn.cn/announce/IOS/announcement/854_1644580545.html"}
)