mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2026-05-09 18:27:56 +08:00
✨ 提供批量 api 接口 (#290)
* 🚧 add batch api * 🚧 support batch in scheduler * ✨ use batch api in bilibili-live ✅ patch platform_manager directly ✨ use batch api in bilibili-live ✅ patch platform_manager directly * ♻️ refactor var name * 🐛 fix test * 🐛 fix scheduler * 🐛 fix test
This commit is contained in:
@@ -201,6 +201,7 @@ class Bilibililive(StatusChange):
|
||||
scheduler = BilibiliSchedConf
|
||||
name = "Bilibili直播"
|
||||
has_target = True
|
||||
use_batch = True
|
||||
|
||||
@unique
|
||||
class LiveStatus(Enum):
|
||||
@@ -281,12 +282,11 @@ class Bilibililive(StatusChange):
|
||||
keyframe="",
|
||||
)
|
||||
|
||||
async def get_status(self, target: Target) -> Info:
|
||||
params = {"uids[]": target}
|
||||
async def batch_get_status(self, targets: list[Target]) -> list[Info]:
|
||||
# https://github.com/SocialSisterYi/bilibili-API-collect/blob/master/docs/live/info.md#批量查询直播间状态
|
||||
res = await self.client.get(
|
||||
"https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids",
|
||||
params=params,
|
||||
params={"uids[]": targets},
|
||||
timeout=4.0,
|
||||
)
|
||||
res_dict = res.json()
|
||||
@@ -294,11 +294,14 @@ class Bilibililive(StatusChange):
|
||||
if res_dict["code"] != 0:
|
||||
raise self.FetchError()
|
||||
|
||||
data = res_dict.get("data")
|
||||
if not data:
|
||||
return self._gen_empty_info(uid=int(target))
|
||||
room_data = data[target]
|
||||
return self.Info.parse_obj(room_data)
|
||||
data = res_dict.get("data", {})
|
||||
infos = []
|
||||
for target in targets:
|
||||
if target in data.keys():
|
||||
infos.append(self.Info.parse_obj(data[target]))
|
||||
else:
|
||||
infos.append(self._gen_empty_info(int(target)))
|
||||
return infos
|
||||
|
||||
def compare_status(self, _: Target, old_status: Info, new_status: Info) -> list[RawPost]:
|
||||
action = Bilibililive.LiveAction
|
||||
|
||||
@@ -2,11 +2,11 @@ import ssl
|
||||
import json
|
||||
import time
|
||||
import typing
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import Collection
|
||||
from typing import Any, TypeVar, ParamSpec
|
||||
from collections.abc import Callable, Awaitable, Collection
|
||||
|
||||
import httpx
|
||||
from httpx import AsyncClient
|
||||
@@ -16,7 +16,7 @@ from nonebot_plugin_saa import PlatformTarget
|
||||
from ..post import Post
|
||||
from ..plugin_config import plugin_config
|
||||
from ..utils import ProcessContext, SchedulerConfig
|
||||
from ..types import Tag, Target, RawPost, Category, UserSubInfo
|
||||
from ..types import Tag, Target, RawPost, SubUnit, Category
|
||||
|
||||
|
||||
class CategoryNotSupport(Exception):
|
||||
@@ -44,6 +44,26 @@ class RegistryMeta(type):
|
||||
super().__init__(name, bases, namespace, **kwargs)
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
async def catch_network_error(func: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs) -> R | None:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except httpx.RequestError as err:
|
||||
if plugin_config.bison_show_network_warning:
|
||||
logger.warning(f"network connection error: {type(err)}, url: {err.request.url}")
|
||||
return None
|
||||
except ssl.SSLError as err:
|
||||
if plugin_config.bison_show_network_warning:
|
||||
logger.warning(f"ssl error: {err}")
|
||||
return None
|
||||
except json.JSONDecodeError as err:
|
||||
logger.warning(f"json error, parsing: {err.doc}")
|
||||
raise err
|
||||
|
||||
|
||||
class PlatformMeta(RegistryMeta):
|
||||
categories: dict[Category, str]
|
||||
store: dict[Target, Any]
|
||||
@@ -75,6 +95,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
||||
registry: list[type["Platform"]]
|
||||
client: AsyncClient
|
||||
reverse_category: dict[str, Category]
|
||||
use_batch: bool = False
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
@@ -82,25 +103,18 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
...
|
||||
|
||||
async def do_fetch_new_post(
|
||||
self, target: Target, users: list[UserSubInfo]
|
||||
) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
try:
|
||||
return await self.fetch_new_post(target, users)
|
||||
except httpx.RequestError as err:
|
||||
if plugin_config.bison_show_network_warning:
|
||||
logger.warning(f"network connection error: {type(err)}, url: {err.request.url}")
|
||||
return []
|
||||
except ssl.SSLError as err:
|
||||
if plugin_config.bison_show_network_warning:
|
||||
logger.warning(f"ssl error: {err}")
|
||||
return []
|
||||
except json.JSONDecodeError as err:
|
||||
logger.warning(f"json error, parsing: {err.doc}")
|
||||
raise err
|
||||
async def do_fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
return await catch_network_error(self.fetch_new_post, sub_unit) or []
|
||||
|
||||
@abstractmethod
|
||||
async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
...
|
||||
|
||||
async def do_batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
return await catch_network_error(self.batch_fetch_new_post, sub_units) or []
|
||||
|
||||
@abstractmethod
|
||||
async def parse(self, raw_post: RawPost) -> Post:
|
||||
@@ -190,10 +204,10 @@ class Platform(metaclass=PlatformABCMeta, base=True):
|
||||
return res
|
||||
|
||||
async def dispatch_user_post(
|
||||
self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]
|
||||
self, new_posts: list[RawPost], sub_unit: SubUnit
|
||||
) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
res: list[tuple[PlatformTarget, list[Post]]] = []
|
||||
for user, cats, required_tags in users:
|
||||
for user, cats, required_tags in sub_unit.user_sub_infos:
|
||||
user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
|
||||
user_post: list[Post] = []
|
||||
for raw_post in user_raw_post:
|
||||
@@ -235,6 +249,12 @@ class MessageProcess(Platform, abstract=True):
|
||||
@abstractmethod
|
||||
async def get_sub_list(self, target: Target) -> list[RawPost]:
|
||||
"Get post list of the given target"
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def batch_get_sub_list(self, targets: list[Target]) -> list[list[RawPost]]:
|
||||
"Get post list of the given targets"
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_date(self, post: RawPost) -> int | None:
|
||||
@@ -298,9 +318,12 @@ class NewMessage(MessageProcess, abstract=True):
|
||||
self.set_stored_data(target, store)
|
||||
return res
|
||||
|
||||
async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
post_list = await self.get_sub_list(target)
|
||||
new_posts = await self.filter_common_with_diff(target, post_list)
|
||||
async def _handle_new_post(
|
||||
self,
|
||||
post_list: list[RawPost],
|
||||
sub_unit: SubUnit,
|
||||
) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
new_posts = await self.filter_common_with_diff(sub_unit.sub_target, post_list)
|
||||
if not new_posts:
|
||||
return []
|
||||
else:
|
||||
@@ -308,14 +331,27 @@ class NewMessage(MessageProcess, abstract=True):
|
||||
logger.info(
|
||||
"fetch new post from {} {}: {}".format(
|
||||
self.platform_name,
|
||||
target if self.has_target else "-",
|
||||
sub_unit.sub_target if self.has_target else "-",
|
||||
self.get_id(post),
|
||||
)
|
||||
)
|
||||
res = await self.dispatch_user_post(target, new_posts, users)
|
||||
res = await self.dispatch_user_post(new_posts, sub_unit)
|
||||
self.parse_cache = {}
|
||||
return res
|
||||
|
||||
async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
post_list = await self.get_sub_list(sub_unit.sub_target)
|
||||
return await self._handle_new_post(post_list, sub_unit)
|
||||
|
||||
async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
if not self.has_target:
|
||||
raise RuntimeError("Target without target should not use batch api") # pragma: no cover
|
||||
posts_set = await self.batch_get_sub_list([x[0] for x in sub_units])
|
||||
res = []
|
||||
for sub_unit, posts in zip(sub_units, posts_set):
|
||||
res.extend(await self._handle_new_post(posts, sub_unit))
|
||||
return res
|
||||
|
||||
|
||||
class StatusChange(Platform, abstract=True):
|
||||
"Watch a status, and fire a post when status changes"
|
||||
@@ -327,6 +363,10 @@ class StatusChange(Platform, abstract=True):
|
||||
async def get_status(self, target: Target) -> Any:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def batch_get_status(self, targets: list[Target]) -> list[Any]:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]:
|
||||
...
|
||||
@@ -335,34 +375,51 @@ class StatusChange(Platform, abstract=True):
|
||||
async def parse(self, raw_post: RawPost) -> Post:
|
||||
...
|
||||
|
||||
async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
try:
|
||||
new_status = await self.get_status(target)
|
||||
except self.FetchError as err:
|
||||
logger.warning(f"fetching {self.name}-{target} error: {err}")
|
||||
raise
|
||||
async def _handle_status_change(
|
||||
self, new_status: Any, sub_unit: SubUnit
|
||||
) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
res = []
|
||||
if old_status := self.get_stored_data(target):
|
||||
diff = self.compare_status(target, old_status, new_status)
|
||||
if old_status := self.get_stored_data(sub_unit.sub_target):
|
||||
diff = self.compare_status(sub_unit.sub_target, old_status, new_status)
|
||||
if diff:
|
||||
logger.info(
|
||||
"status changes {} {}: {} -> {}".format(
|
||||
self.platform_name,
|
||||
target if self.has_target else "-",
|
||||
sub_unit.sub_target if self.has_target else "-",
|
||||
old_status,
|
||||
new_status,
|
||||
)
|
||||
)
|
||||
res = await self.dispatch_user_post(target, diff, users)
|
||||
self.set_stored_data(target, new_status)
|
||||
res = await self.dispatch_user_post(diff, sub_unit)
|
||||
self.set_stored_data(sub_unit.sub_target, new_status)
|
||||
return res
|
||||
|
||||
async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
try:
|
||||
new_status = await self.get_status(sub_unit.sub_target)
|
||||
except self.FetchError as err:
|
||||
logger.warning(f"fetching {self.name}-{sub_unit.sub_target} error: {err}")
|
||||
raise
|
||||
return await self._handle_status_change(new_status, sub_unit)
|
||||
|
||||
async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
if not self.has_target:
|
||||
raise RuntimeError("Target without target should not use batch api") # pragma: no cover
|
||||
new_statuses = await self.batch_get_status([x[0] for x in sub_units])
|
||||
res = []
|
||||
for sub_unit, new_status in zip(sub_units, new_statuses):
|
||||
res.extend(await self._handle_status_change(new_status, sub_unit))
|
||||
return res
|
||||
|
||||
|
||||
class SimplePost(MessageProcess, abstract=True):
|
||||
class SimplePost(NewMessage, abstract=True):
|
||||
"Fetch a list of messages, dispatch it to different users"
|
||||
|
||||
async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
new_posts = await self.get_sub_list(target)
|
||||
async def _handle_new_post(
|
||||
self,
|
||||
new_posts: list[RawPost],
|
||||
sub_unit: SubUnit,
|
||||
) -> list[tuple[PlatformTarget, list[Post]]]:
|
||||
if not new_posts:
|
||||
return []
|
||||
else:
|
||||
@@ -370,11 +427,11 @@ class SimplePost(MessageProcess, abstract=True):
|
||||
logger.info(
|
||||
"fetch new post from {} {}: {}".format(
|
||||
self.platform_name,
|
||||
target if self.has_target else "-",
|
||||
sub_unit.sub_target if self.has_target else "-",
|
||||
self.get_id(post),
|
||||
)
|
||||
)
|
||||
res = await self.dispatch_user_post(target, new_posts, users)
|
||||
res = await self.dispatch_user_post(new_posts, sub_unit)
|
||||
self.parse_cache = {}
|
||||
return res
|
||||
|
||||
@@ -422,10 +479,10 @@ def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
|
||||
async def get_target_name(cls, client: AsyncClient, target: Target):
|
||||
return await platform_list[0].get_target_name(client, target)
|
||||
|
||||
async def fetch_new_post(self: "NoTargetGroup", target: Target, users: list[UserSubInfo]):
|
||||
async def fetch_new_post(self: "NoTargetGroup", sub_unit: SubUnit):
|
||||
res = defaultdict(list)
|
||||
for platform in self.platform_obj_list:
|
||||
platform_res = await platform.fetch_new_post(target=target, users=users)
|
||||
platform_res = await platform.fetch_new_post(sub_unit)
|
||||
for user, posts in platform_res:
|
||||
res[user].extend(posts)
|
||||
return [[key, val] for key, val in res.items()]
|
||||
|
||||
Reference in New Issue
Block a user