提供批量 api 接口 (#290)

* 🚧 add batch api

* 🚧 support batch in scheduler

*  use batch api in bilibili-live

 patch platform_manager directly

 use batch api in bilibili-live

 patch platform_manager directly

* ♻️ refactor var name

* 🐛 fix test

* 🐛 fix scheduler

* 🐛 fix test
This commit is contained in:
felinae98
2023-08-29 21:12:42 +08:00
committed by GitHub
parent 219e3ba5c6
commit e7dcfdee50
16 changed files with 519 additions and 166 deletions
+11 -8
View File
@@ -201,6 +201,7 @@ class Bilibililive(StatusChange):
scheduler = BilibiliSchedConf
name = "Bilibili直播"
has_target = True
use_batch = True
@unique
class LiveStatus(Enum):
@@ -281,12 +282,11 @@ class Bilibililive(StatusChange):
keyframe="",
)
async def get_status(self, target: Target) -> Info:
params = {"uids[]": target}
async def batch_get_status(self, targets: list[Target]) -> list[Info]:
# https://github.com/SocialSisterYi/bilibili-API-collect/blob/master/docs/live/info.md#批量查询直播间状态
res = await self.client.get(
"https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids",
params=params,
params={"uids[]": targets},
timeout=4.0,
)
res_dict = res.json()
@@ -294,11 +294,14 @@ class Bilibililive(StatusChange):
if res_dict["code"] != 0:
raise self.FetchError()
data = res_dict.get("data")
if not data:
return self._gen_empty_info(uid=int(target))
room_data = data[target]
return self.Info.parse_obj(room_data)
data = res_dict.get("data", {})
infos = []
for target in targets:
if target in data.keys():
infos.append(self.Info.parse_obj(data[target]))
else:
infos.append(self._gen_empty_info(int(target)))
return infos
def compare_status(self, _: Target, old_status: Info, new_status: Info) -> list[RawPost]:
action = Bilibililive.LiveAction
+102 -45
View File
@@ -2,11 +2,11 @@ import ssl
import json
import time
import typing
from typing import Any
from dataclasses import dataclass
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Collection
from typing import Any, TypeVar, ParamSpec
from collections.abc import Callable, Awaitable, Collection
import httpx
from httpx import AsyncClient
@@ -16,7 +16,7 @@ from nonebot_plugin_saa import PlatformTarget
from ..post import Post
from ..plugin_config import plugin_config
from ..utils import ProcessContext, SchedulerConfig
from ..types import Tag, Target, RawPost, Category, UserSubInfo
from ..types import Tag, Target, RawPost, SubUnit, Category
class CategoryNotSupport(Exception):
@@ -44,6 +44,26 @@ class RegistryMeta(type):
super().__init__(name, bases, namespace, **kwargs)
P = ParamSpec("P")
R = TypeVar("R")
async def catch_network_error(func: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs) -> R | None:
try:
return await func(*args, **kwargs)
except httpx.RequestError as err:
if plugin_config.bison_show_network_warning:
logger.warning(f"network connection error: {type(err)}, url: {err.request.url}")
return None
except ssl.SSLError as err:
if plugin_config.bison_show_network_warning:
logger.warning(f"ssl error: {err}")
return None
except json.JSONDecodeError as err:
logger.warning(f"json error, parsing: {err.doc}")
raise err
class PlatformMeta(RegistryMeta):
categories: dict[Category, str]
store: dict[Target, Any]
@@ -75,6 +95,7 @@ class Platform(metaclass=PlatformABCMeta, base=True):
registry: list[type["Platform"]]
client: AsyncClient
reverse_category: dict[str, Category]
use_batch: bool = False
@classmethod
@abstractmethod
@@ -82,25 +103,18 @@ class Platform(metaclass=PlatformABCMeta, base=True):
...
@abstractmethod
async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]:
async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
...
async def do_fetch_new_post(
self, target: Target, users: list[UserSubInfo]
) -> list[tuple[PlatformTarget, list[Post]]]:
try:
return await self.fetch_new_post(target, users)
except httpx.RequestError as err:
if plugin_config.bison_show_network_warning:
logger.warning(f"network connection error: {type(err)}, url: {err.request.url}")
return []
except ssl.SSLError as err:
if plugin_config.bison_show_network_warning:
logger.warning(f"ssl error: {err}")
return []
except json.JSONDecodeError as err:
logger.warning(f"json error, parsing: {err.doc}")
raise err
async def do_fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
return await catch_network_error(self.fetch_new_post, sub_unit) or []
@abstractmethod
async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
...
async def do_batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
return await catch_network_error(self.batch_fetch_new_post, sub_units) or []
@abstractmethod
async def parse(self, raw_post: RawPost) -> Post:
@@ -190,10 +204,10 @@ class Platform(metaclass=PlatformABCMeta, base=True):
return res
async def dispatch_user_post(
self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]
self, new_posts: list[RawPost], sub_unit: SubUnit
) -> list[tuple[PlatformTarget, list[Post]]]:
res: list[tuple[PlatformTarget, list[Post]]] = []
for user, cats, required_tags in users:
for user, cats, required_tags in sub_unit.user_sub_infos:
user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
user_post: list[Post] = []
for raw_post in user_raw_post:
@@ -235,6 +249,12 @@ class MessageProcess(Platform, abstract=True):
@abstractmethod
async def get_sub_list(self, target: Target) -> list[RawPost]:
"Get post list of the given target"
raise NotImplementedError()
@abstractmethod
async def batch_get_sub_list(self, targets: list[Target]) -> list[list[RawPost]]:
"Get post list of the given targets"
raise NotImplementedError()
@abstractmethod
def get_date(self, post: RawPost) -> int | None:
@@ -298,9 +318,12 @@ class NewMessage(MessageProcess, abstract=True):
self.set_stored_data(target, store)
return res
async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]:
post_list = await self.get_sub_list(target)
new_posts = await self.filter_common_with_diff(target, post_list)
async def _handle_new_post(
self,
post_list: list[RawPost],
sub_unit: SubUnit,
) -> list[tuple[PlatformTarget, list[Post]]]:
new_posts = await self.filter_common_with_diff(sub_unit.sub_target, post_list)
if not new_posts:
return []
else:
@@ -308,14 +331,27 @@ class NewMessage(MessageProcess, abstract=True):
logger.info(
"fetch new post from {} {}: {}".format(
self.platform_name,
target if self.has_target else "-",
sub_unit.sub_target if self.has_target else "-",
self.get_id(post),
)
)
res = await self.dispatch_user_post(target, new_posts, users)
res = await self.dispatch_user_post(new_posts, sub_unit)
self.parse_cache = {}
return res
async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
post_list = await self.get_sub_list(sub_unit.sub_target)
return await self._handle_new_post(post_list, sub_unit)
async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
if not self.has_target:
raise RuntimeError("Target without target should not use batch api") # pragma: no cover
posts_set = await self.batch_get_sub_list([x[0] for x in sub_units])
res = []
for sub_unit, posts in zip(sub_units, posts_set):
res.extend(await self._handle_new_post(posts, sub_unit))
return res
class StatusChange(Platform, abstract=True):
"Watch a status, and fire a post when status changes"
@@ -327,6 +363,10 @@ class StatusChange(Platform, abstract=True):
async def get_status(self, target: Target) -> Any:
...
@abstractmethod
async def batch_get_status(self, targets: list[Target]) -> list[Any]:
...
@abstractmethod
def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]:
...
@@ -335,34 +375,51 @@ class StatusChange(Platform, abstract=True):
async def parse(self, raw_post: RawPost) -> Post:
...
async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]:
try:
new_status = await self.get_status(target)
except self.FetchError as err:
logger.warning(f"fetching {self.name}-{target} error: {err}")
raise
async def _handle_status_change(
self, new_status: Any, sub_unit: SubUnit
) -> list[tuple[PlatformTarget, list[Post]]]:
res = []
if old_status := self.get_stored_data(target):
diff = self.compare_status(target, old_status, new_status)
if old_status := self.get_stored_data(sub_unit.sub_target):
diff = self.compare_status(sub_unit.sub_target, old_status, new_status)
if diff:
logger.info(
"status changes {} {}: {} -> {}".format(
self.platform_name,
target if self.has_target else "-",
sub_unit.sub_target if self.has_target else "-",
old_status,
new_status,
)
)
res = await self.dispatch_user_post(target, diff, users)
self.set_stored_data(target, new_status)
res = await self.dispatch_user_post(diff, sub_unit)
self.set_stored_data(sub_unit.sub_target, new_status)
return res
async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
try:
new_status = await self.get_status(sub_unit.sub_target)
except self.FetchError as err:
logger.warning(f"fetching {self.name}-{sub_unit.sub_target} error: {err}")
raise
return await self._handle_status_change(new_status, sub_unit)
async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
if not self.has_target:
raise RuntimeError("Target without target should not use batch api") # pragma: no cover
new_statuses = await self.batch_get_status([x[0] for x in sub_units])
res = []
for sub_unit, new_status in zip(sub_units, new_statuses):
res.extend(await self._handle_status_change(new_status, sub_unit))
return res
class SimplePost(MessageProcess, abstract=True):
class SimplePost(NewMessage, abstract=True):
"Fetch a list of messages, dispatch it to different users"
async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[PlatformTarget, list[Post]]]:
new_posts = await self.get_sub_list(target)
async def _handle_new_post(
self,
new_posts: list[RawPost],
sub_unit: SubUnit,
) -> list[tuple[PlatformTarget, list[Post]]]:
if not new_posts:
return []
else:
@@ -370,11 +427,11 @@ class SimplePost(MessageProcess, abstract=True):
logger.info(
"fetch new post from {} {}: {}".format(
self.platform_name,
target if self.has_target else "-",
sub_unit.sub_target if self.has_target else "-",
self.get_id(post),
)
)
res = await self.dispatch_user_post(target, new_posts, users)
res = await self.dispatch_user_post(new_posts, sub_unit)
self.parse_cache = {}
return res
@@ -422,10 +479,10 @@ def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
async def get_target_name(cls, client: AsyncClient, target: Target):
return await platform_list[0].get_target_name(client, target)
async def fetch_new_post(self: "NoTargetGroup", target: Target, users: list[UserSubInfo]):
async def fetch_new_post(self: "NoTargetGroup", sub_unit: SubUnit):
res = defaultdict(list)
for platform in self.platform_obj_list:
platform_res = await platform.fetch_new_post(target=target, users=users)
platform_res = await platform.fetch_new_post(sub_unit)
for user, posts in platform_res:
res[user].extend(posts)
return [[key, val] for key, val in res.items()]
+3 -1
View File
@@ -29,7 +29,9 @@ async def init_scheduler():
for scheduler_config, target_list in _schedule_class_dict.items():
schedulable_args = []
for target in target_list:
schedulable_args.append((target.platform_name, T_Target(target.target)))
schedulable_args.append(
(target.platform_name, T_Target(target.target), platform_manager[target.platform_name].use_batch)
)
platform_name_list = _schedule_class_platform_dict[scheduler_config]
scheduler_dict[scheduler_config] = Scheduler(scheduler_config, schedulable_args, platform_name_list)
config.register_add_target_hook(handle_insert_new_target)
+44 -8
View File
@@ -1,12 +1,13 @@
from dataclasses import dataclass
from collections import defaultdict
from nonebot.log import logger
from nonebot_plugin_apscheduler import scheduler
from nonebot_plugin_saa.utils.exceptions import NoBotFound
from ..types import Target
from ..config import config
from ..send import send_msgs
from ..types import Target, SubUnit
from ..platform import platform_manager
from ..utils import ProcessContext, SchedulerConfig
@@ -16,15 +17,18 @@ class Schedulable:
platform_name: str
target: Target
current_weight: int
use_batch: bool = False
class Scheduler:
schedulable_list: list[Schedulable]
schedulable_list: list[Schedulable]  # for loading weight from db
batch_api_target_cache: dict[str, dict[Target, list[Target]]] # platform_name -> (target -> [target])
batch_platform_name_targets_cache: dict[str, list[Target]]
def __init__(
self,
scheduler_config: type[SchedulerConfig],
schedulables: list[tuple[str, Target]],
schedulables: list[tuple[str, Target, bool]], # [(platform_name, target, use_batch)]
platform_name_list: list[str],
):
self.name = scheduler_config.name
@@ -33,9 +37,17 @@ class Scheduler:
raise RuntimeError(f"{self.name} not found")
self.scheduler_config = scheduler_config
self.scheduler_config_obj = self.scheduler_config()
self.schedulable_list = []
for platform_name, target in schedulables:
self.schedulable_list.append(Schedulable(platform_name=platform_name, target=target, current_weight=0))
self.batch_platform_name_targets_cache: dict[str, list[Target]] = defaultdict(list)
for platform_name, target, use_batch in schedulables:
if use_batch:
self.batch_platform_name_targets_cache[platform_name].append(target)
self.schedulable_list.append(
Schedulable(platform_name=platform_name, target=target, current_weight=0, use_batch=use_batch)
)
self._refresh_batch_api_target_cache()
self.platform_name_list = platform_name_list
self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值
logger.info(
@@ -48,6 +60,12 @@ class Scheduler:
**self.scheduler_config.schedule_setting,
)
def _refresh_batch_api_target_cache(self):
self.batch_api_target_cache = defaultdict(dict)
for platform_name, targets in self.batch_platform_name_targets_cache.items():
for target in targets:
self.batch_api_target_cache[platform_name][target] = targets
async def get_next_schedulable(self) -> Schedulable | None:
if not self.schedulable_list:
return None
@@ -69,14 +87,24 @@ class Scheduler:
if not (schedulable := await self.get_next_schedulable()):
return
logger.trace(f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}")
send_userinfo_list = await config.get_platform_target_subscribers(schedulable.platform_name, schedulable.target)
client = await self.scheduler_config_obj.get_client(schedulable.target)
context.register_to_client(client)
try:
platform_obj = platform_manager[schedulable.platform_name](context, client)
to_send = await platform_obj.do_fetch_new_post(schedulable.target, send_userinfo_list)
if schedulable.use_batch:
batch_targets = self.batch_api_target_cache[schedulable.platform_name][schedulable.target]
sub_units = []
for batch_target in batch_targets:
userinfo = await config.get_platform_target_subscribers(schedulable.platform_name, batch_target)
sub_units.append(SubUnit(batch_target, userinfo))
to_send = await platform_obj.do_batch_fetch_new_post(sub_units)
else:
send_userinfo_list = await config.get_platform_target_subscribers(
schedulable.platform_name, schedulable.target
)
to_send = await platform_obj.do_fetch_new_post(SubUnit(schedulable.target, send_userinfo_list))
except Exception as err:
records = context.gen_req_records()
for record in records:
@@ -101,9 +129,18 @@ class Scheduler:
def insert_new_schedulable(self, platform_name: str, target: Target):
self.pre_weight_val += 1000
self.schedulable_list.append(Schedulable(platform_name, target, 1000))
if platform_manager[platform_name].use_batch:
self.batch_platform_name_targets_cache[platform_name].append(target)
self._refresh_batch_api_target_cache()
logger.info(f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})")
def delete_schedulable(self, platform_name, target: Target):
if platform_manager[platform_name].use_batch:
self.batch_platform_name_targets_cache[platform_name].remove(target)
self._refresh_batch_api_target_cache()
if not self.schedulable_list:
return
to_find_idx = None
@@ -114,4 +151,3 @@ class Scheduler:
if to_find_idx is not None:
deleted_schdulable = self.schedulable_list.pop(to_find_idx)
self.pre_weight_val -= deleted_schdulable.current_weight
return
+5
View File
@@ -53,3 +53,8 @@ class ApiError(Exception):
def __init__(self, url: URL) -> None:
msg = f"api {url} error"
super().__init__(msg)
class SubUnit(NamedTuple):
sub_target: Target
user_sub_infos: list[UserSubInfo]