505 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import ssl
import json
import time
import typing
from dataclasses import dataclass
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, TypeVar, ParamSpec
from collections.abc import Callable, Awaitable, Collection
import httpx
from httpx import AsyncClient
from nonebot.log import logger
from nonebot_plugin_saa import PlatformTarget
from ..post import Post
from ..utils import Site, ProcessContext
from ..plugin_config import plugin_config
from ..types import Tag, Target, RawPost, SubUnit, Category
class CategoryNotSupport(Exception):
    """Signal a post whose category is known but deliberately unsupported.

    Raise this from ``get_category`` when the category of the post can be
    identified, but the platform does not want to support it or cannot
    parse it yet.
    """
class CategoryNotRecognize(Exception):
    """Signal a post whose category cannot be determined.

    Raise this from ``get_category`` when the category of the post is unknown.
    """
class RegistryMeta(type):
    """Metaclass that collects concrete subclasses in a shared ``registry`` list.

    Class-definition keywords understood by this metaclass:

    * ``base=True``: the class is the registry root and gets a fresh, empty
      ``registry`` list.
    * ``abstract=True``: the class is not appended to the registry.

    A class created with neither keyword is appended to the ``registry`` it
    inherits.
    """

    def __new__(cls, name, bases, namespace, **kwargs):
        # The keyword flags are interpreted in __init__; they are not passed
        # on to type.__new__.
        return super().__new__(cls, name, bases, namespace)

    def __init__(cls, name, bases, namespace, **kwargs):
        is_registry_root = kwargs.get("base")
        is_abstract = kwargs.get("abstract")
        if is_registry_root:
            # registry root: start a new list
            cls.registry = []
        elif not is_abstract:
            # concrete subclass: record it in the inherited list
            cls.registry.append(cls)
        super().__init__(name, bases, namespace, **kwargs)
P = ParamSpec("P")
R = TypeVar("R")
async def catch_network_error(func: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs) -> R | None:
try:
return await func(*args, **kwargs)
except httpx.RequestError as err:
if plugin_config.bison_show_network_warning:
logger.warning(f"network connection error: {type(err)}, url: {err.request.url}")
return None
except ssl.SSLError as err:
if plugin_config.bison_show_network_warning:
logger.warning(f"ssl error: {err}")
return None
except json.JSONDecodeError as err:
logger.warning(f"json error, parsing: {err.doc}")
raise err
class PlatformMeta(RegistryMeta):
    """Registry metaclass that also prepares per-class platform state.

    Every class created through it receives:

    * ``reverse_category``: the class's ``categories`` mapping inverted
      (value -> key); empty when ``categories`` is missing or empty.
    * ``store``: a fresh, empty dict of its own.
    """

    categories: dict[Category, str]
    store: dict[Target, Any]

    def __init__(cls, name, bases, namespace, **kwargs):
        declared = getattr(cls, "categories", None)
        cls.reverse_category = {label: key for key, label in declared.items()} if declared else {}
        cls.store = {}
        super().__init__(name, bases, namespace, **kwargs)
class PlatformABCMeta(PlatformMeta, ABC):
    """Metaclass used by ``Platform``: ``PlatformMeta`` combined with ``ABC``."""
class Platform(metaclass=PlatformABCMeta, base=True):
    """Root class of the platform registry (created with ``base=True``).

    Declares the attributes and hooks a platform provides, and implements the
    per-user category/tag filtering and the dispatching of raw posts.
    """

    site: type[Site]
    ctx: ProcessContext
    is_common: bool
    enabled: bool
    name: str
    has_target: bool
    categories: dict[Category, str]
    enable_tag: bool
    platform_name: str
    parse_target_promot: str | None = None
    registry: list[type["Platform"]]
    reverse_category: dict[str, Category]
    use_batch: bool = False
    # TODO: restrict the theme names that may be used
    default_theme: str = "basic"

    def __init__(self, context: ProcessContext):
        super().__init__()
        self.ctx = context

    class ParseTargetException(Exception):
        """Exception carrying an optional ``prompt`` attribute."""

        def __init__(self, *args: object, prompt: str | None = None) -> None:
            super().__init__(*args)
            # prompt message about the user's input
            self.prompt = prompt

    # ------------------------------------------------------------------
    # hooks
    # ------------------------------------------------------------------
    @classmethod
    @abstractmethod
    async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: ...

    @abstractmethod
    async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]: ...

    @abstractmethod
    async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]: ...

    @abstractmethod
    async def parse(self, raw_post: RawPost) -> Post: ...

    @abstractmethod
    def get_tags(self, raw_post: RawPost) -> Collection[Tag] | None:
        """Return the tags of the given RawPost."""

    @abstractmethod
    def get_category(self, post: RawPost) -> Category | None:
        """Return the category of the given RawPost."""
        raise NotImplementedError()

    # ------------------------------------------------------------------
    # entry points wrapping the hooks
    # ------------------------------------------------------------------
    async def do_fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
        """Run ``fetch_new_post`` via ``catch_network_error``; a falsy result becomes ``[]``."""
        fetched = await catch_network_error(self.fetch_new_post, sub_unit)
        return fetched or []

    async def do_batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
        """Run ``batch_fetch_new_post`` via ``catch_network_error``; a falsy result becomes ``[]``."""
        fetched = await catch_network_error(self.batch_fetch_new_post, sub_units)
        return fetched or []

    async def do_parse(self, raw_post: RawPost) -> Post:
        """Parse entry point that is actually called; delegates to ``parse``."""
        return await self.parse(raw_post)

    @classmethod
    async def parse_target(cls, target_string: str) -> Target:
        """Wrap the given string in ``Target``."""
        return Target(target_string)

    # ------------------------------------------------------------------
    # per-class storage
    # ------------------------------------------------------------------
    @classmethod
    def get_stored_data(cls, target: Target) -> Any:
        """Return the data stored for ``target`` in ``cls.store`` (``None`` if absent)."""
        return cls.store.get(target)

    @classmethod
    def set_stored_data(cls, target: Target, data: Any):
        """Store ``data`` for ``target`` in ``cls.store``."""
        cls.store[target] = data

    # ------------------------------------------------------------------
    # tag / category filtering
    # ------------------------------------------------------------------
    def tag_separator(self, stored_tags: list[Tag]) -> tuple[list[Tag], list[Tag]]:
        """Split stored tags into a ``(subscribed, banned)`` tuple.

        Tags starting with ``~`` are banned tags; they are returned with the
        leading ``~`` characters stripped.
        """
        banned = [tag.lstrip("~") for tag in stored_tags if tag.startswith("~")]
        subscribed = [tag for tag in stored_tags if not tag.startswith("~")]
        return subscribed, banned

    def is_banned_post(
        self,
        post_tags: Collection[Tag],
        subscribed_tags: list[Tag],
        banned_tags: list[Tag],
    ) -> bool:
        """Decide whether a post must be dropped because of its tags.

        * Any banned tag on the post -> ``True`` (highest priority).
        * Otherwise, with subscribed tags configured, the post is banned
          unless it carries at least one of them.
        * With no subscribed tags configured -> ``False``.
        """
        tags_on_post = post_tags or []
        # banned tags win over everything else
        if banned_tags and any(tag in banned_tags for tag in tags_on_post):
            return True
        # only after the banned check are the subscribed tags considered
        if not subscribed_tags:
            return False
        return not any(tag in subscribed_tags for tag in tags_on_post)

    async def filter_user_custom(
        self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]
    ) -> list[RawPost]:
        """Keep the raw posts matching one user's category and tag settings."""
        kept: list[RawPost] = []
        for raw_post in raw_post_list:
            if self.categories:
                category = self.get_category(raw_post)
                if cats and category not in cats:
                    continue
            if self.enable_tag and tags:
                tags_on_post = self.get_tags(raw_post)
                if isinstance(tags_on_post, Collection):
                    wanted, unwanted = self.tag_separator(tags)
                    if self.is_banned_post(tags_on_post, wanted, unwanted):
                        continue
            kept.append(raw_post)
        return kept

    async def dispatch_user_post(
        self, new_posts: list[RawPost], sub_unit: SubUnit
    ) -> list[tuple[PlatformTarget, list[Post]]]:
        """Filter and parse ``new_posts`` for every user of ``sub_unit``.

        Returns one ``(user, posts)`` tuple per entry of
        ``sub_unit.user_sub_infos``.
        """
        dispatched: list[tuple[PlatformTarget, list[Post]]] = []
        for user, cats, required_tags in sub_unit.user_sub_infos:
            matching = await self.filter_user_custom(new_posts, cats, required_tags)
            parsed: list[Post] = [await self.do_parse(raw_post) for raw_post in matching]
            dispatched.append((user, parsed))
        return dispatched
class MessageProcess(Platform, abstract=True):
    """General message pipeline: fetch, parse and filter posts."""

    def __init__(self, ctx: ProcessContext):
        super().__init__(ctx)
        # parsed posts keyed by the id returned from get_id
        self.parse_cache: dict[Any, Post] = {}

    @abstractmethod
    def get_id(self, post: RawPost) -> Any:
        """Return the post id of the given RawPost."""

    @abstractmethod
    def get_date(self, post: RawPost) -> int | None:
        """Return the post timestamp, or None if the time cannot be obtained."""

    @abstractmethod
    async def get_sub_list(self, target: Target) -> list[RawPost]:
        """Return the post list of the given target."""
        raise NotImplementedError()

    @abstractmethod
    async def batch_get_sub_list(self, targets: list[Target]) -> list[list[RawPost]]:
        """Return the post lists of the given targets."""
        raise NotImplementedError()

    async def do_parse(self, raw_post: RawPost) -> Post:
        """Parse ``raw_post`` with caching and up to three attempts.

        A post whose id is already in ``parse_cache`` is returned from the
        cache. Otherwise ``parse`` is tried up to three times; the exception
        of the last failed attempt is propagated.
        """
        post_id = self.get_id(raw_post)
        if post_id in self.parse_cache:
            return self.parse_cache[post_id]

        max_attempts = 3
        for attempt in range(1, max_attempts + 1):
            try:
                self.parse_cache[post_id] = await self.parse(raw_post)
                break
            except Exception:
                if attempt == max_attempts:
                    raise
        return self.parse_cache[post_id]

    async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
        """Drop posts that are too old or whose category cannot be handled.

        A post is skipped when it has a timestamp older than two hours while
        ``plugin_config.bison_init_filter`` is enabled, or when
        ``get_category`` raises ``CategoryNotSupport`` or
        ``CategoryNotRecognize``. A ``NotImplementedError`` from
        ``get_category`` is ignored and the post is kept.
        """
        max_age_seconds = 2 * 60 * 60
        accepted = []
        for raw_post in raw_post_list:
            post_time = self.get_date(raw_post)
            if post_time and time.time() - post_time > max_age_seconds and plugin_config.bison_init_filter:
                continue
            try:
                self.get_category(raw_post)
            except CategoryNotSupport as e:
                logger.info("未支持解析的推文类别:" + repr(e) + ",忽略")
                continue
            except CategoryNotRecognize as e:
                logger.warning("未知推文类别:" + repr(e))
                # also emit the request records generated by the context
                for record in self.ctx.gen_req_records():
                    logger.warning(record)
                continue
            except NotImplementedError:
                pass
            accepted.append(raw_post)
        return accepted
class NewMessage(MessageProcess, abstract=True):
    """Fetch a list of messages, keep the new ones and dispatch them to users."""

    @dataclass
    class MessageStorage:
        """Per-target state: whether it was initialised and the post ids seen so far."""

        inited: bool
        exists_posts: set[Any]

    async def filter_common_with_diff(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]:
        """Return the posts of ``raw_post_list`` not seen before for ``target``.

        On the first run for a target with ``plugin_config.bison_init_filter``
        enabled, all post ids are only recorded and nothing is returned.
        """
        candidates = await self.filter_common(raw_post_list)
        storage = self.get_stored_data(target) or self.MessageStorage(False, set())
        fresh = []
        needs_init = not storage.inited and plugin_config.bison_init_filter
        if needs_init:
            # target not initialised yet: remember everything, report nothing
            storage.exists_posts.update(self.get_id(raw_post) for raw_post in candidates)
            logger.info(f"init {self.platform_name}-{target} with {storage.exists_posts}")
            storage.inited = True
        else:
            for raw_post in candidates:
                post_id = self.get_id(raw_post)
                if post_id not in storage.exists_posts:
                    fresh.append(raw_post)
                    storage.exists_posts.add(post_id)
        self.set_stored_data(target, storage)
        logger.trace(f"本次抓取 {len(raw_post_list)} 条,过滤后 {len(candidates)} 条,新消息 {len(fresh)}")
        return fresh

    async def _handle_new_post(
        self,
        post_list: list[RawPost],
        sub_unit: SubUnit,
    ) -> list[tuple[PlatformTarget, list[Post]]]:
        """Diff ``post_list`` against the stored state, log and dispatch the new posts."""
        new_posts = await self.filter_common_with_diff(sub_unit.sub_target, post_list)
        if not new_posts:
            return []

        target_label = sub_unit.sub_target if self.has_target else "-"
        for post in new_posts:
            logger.info("fetch new post from {} {}: {}".format(self.platform_name, target_label, self.get_id(post)))
        dispatched = await self.dispatch_user_post(new_posts, sub_unit)
        # the parse cache is reset after every dispatch
        self.parse_cache = {}
        return dispatched

    async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
        """Fetch the post list of the sub unit's target and handle it."""
        fetched = await self.get_sub_list(sub_unit.sub_target)
        return await self._handle_new_post(fetched, sub_unit)

    async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
        """Fetch the post lists of all sub units in one batch and handle each.

        Raises:
            RuntimeError: If the platform has no target.
        """
        if not self.has_target:
            raise RuntimeError("Target without target should not use batch api")  # pragma: no cover
        targets = [unit[0] for unit in sub_units]
        posts_per_unit = await self.batch_get_sub_list(targets)
        collected = []
        for sub_unit, posts in zip(sub_units, posts_per_unit):
            collected.extend(await self._handle_new_post(posts, sub_unit))
        return collected
class StatusChange(Platform, abstract=True):
    """Watch a status and fire posts when the status changes."""

    class FetchError(RuntimeError):
        """Error raised by ``get_status``; logged and re-raised by ``fetch_new_post``."""

    @abstractmethod
    async def get_status(self, target: Target) -> Any: ...

    @abstractmethod
    async def batch_get_status(self, targets: list[Target]) -> list[Any]: ...

    @abstractmethod
    def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]: ...

    @abstractmethod
    async def parse(self, raw_post: RawPost) -> Post: ...

    async def _handle_status_change(
        self, new_status: Any, sub_unit: SubUnit
    ) -> list[tuple[PlatformTarget, list[Post]]]:
        """Compare ``new_status`` with the stored one, dispatch the diff, store the new status.

        Nothing is dispatched when the stored status is falsy; the new status
        is stored in every case.
        """
        target = sub_unit.sub_target
        dispatched = []
        old_status = self.get_stored_data(target)
        if old_status:
            diff = self.compare_status(target, old_status, new_status)
            if diff:
                target_label = target if self.has_target else "-"
                logger.info(
                    "status changes {} {}: {} -> {}".format(self.platform_name, target_label, old_status, new_status)
                )
                dispatched = await self.dispatch_user_post(diff, sub_unit)
        self.set_stored_data(target, new_status)
        return dispatched

    async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
        """Fetch the current status of the sub unit's target and handle the change.

        Raises:
            StatusChange.FetchError: Re-raised after a warning has been logged.
        """
        try:
            current_status = await self.get_status(sub_unit.sub_target)
        except self.FetchError as err:
            logger.warning(f"fetching {self.name}-{sub_unit.sub_target} error: {err}")
            raise
        return await self._handle_status_change(current_status, sub_unit)

    async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
        """Fetch the statuses of all sub units in one batch and handle each change.

        Raises:
            RuntimeError: If the platform has no target.
        """
        if not self.has_target:
            raise RuntimeError("Target without target should not use batch api")  # pragma: no cover
        targets = [unit[0] for unit in sub_units]
        statuses = await self.batch_get_status(targets)
        collected = []
        for sub_unit, status in zip(sub_units, statuses):
            collected.extend(await self._handle_status_change(status, sub_unit))
        return collected
class SimplePost(NewMessage, abstract=True):
    """Fetch a list of messages and dispatch all of them to users."""

    async def _handle_new_post(
        self,
        post_list: list[RawPost],
        sub_unit: SubUnit,
    ) -> list[tuple[PlatformTarget, list[Post]]]:
        """Log and dispatch every post of ``post_list``, without any diffing."""
        if not post_list:
            return []

        target_label = sub_unit.sub_target if self.has_target else "-"
        for post in post_list:
            logger.info("fetch new post from {} {}: {}".format(self.platform_name, target_label, self.get_id(post)))
        dispatched = await self.dispatch_user_post(post_list, sub_unit)
        # the parse cache is reset after every dispatch
        self.parse_cache = {}
        return dispatched
def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
    """Build a ``Platform`` subclass that groups several target-less platforms.

    The generated class instantiates every platform of ``platform_list`` with
    the same context and merges the results of their ``fetch_new_post`` calls
    per user.

    Args:
        platform_list: Platform classes to group. All of them must have
            ``has_target`` false, share the same ``name`` and ``site``, and
            use disjoint category keys. The first entry provides
            ``platform_name``, ``site``, ``is_common`` and ``get_target_name``.

    Returns:
        The generated class, created with ``abstract=True``.

    Raises:
        RuntimeError: If one of the conditions on ``platform_list`` is violated.
    """
    if typing.TYPE_CHECKING:

        class NoTargetGroup(Platform, abstract=True):
            platform_list: list[type[Platform]]
            platform_obj_list: list[Platform]

    DUMMY_STR = "_DUMMY"

    platform_name = platform_list[0].platform_name
    name = DUMMY_STR
    categories = {}
    site = platform_list[0].site

    for platform in platform_list:
        if platform.has_target:
            raise RuntimeError(f"Platform {platform.name} should have no target")
        if name == DUMMY_STR:
            name = platform.name
        elif name != platform.name:
            raise RuntimeError(f"Platform name for {platform_name} not fit")
        # category keys must be unique across the whole group
        if categories.keys() & platform.categories.keys():
            raise RuntimeError(f"Platform categories for {platform_name} duplicate")
        categories.update(platform.categories)
        if platform.site != site:
            raise RuntimeError(f"Platform scheduler for {platform_name} not fit")

    def __init__(self: "NoTargetGroup", ctx: ProcessContext):
        Platform.__init__(self, ctx)
        # one instance per grouped platform, all sharing the same context
        self.platform_obj_list = [platform_class(ctx) for platform_class in self.platform_list]

    def __str__(self: "NoTargetGroup") -> str:
        return "[" + " ".join(x.name for x in self.platform_list) + "]"

    @classmethod
    async def get_target_name(cls, client: AsyncClient, target: Target):
        return await platform_list[0].get_target_name(client, target)

    async def fetch_new_post(self: "NoTargetGroup", sub_unit: SubUnit):
        merged = defaultdict(list)
        for platform in self.platform_obj_list:
            for user, posts in await platform.fetch_new_post(sub_unit):
                merged[user].extend(posts)
        # tuples, matching the declared result type of Platform.fetch_new_post
        return [(user, posts) for user, posts in merged.items()]

    return type(
        "NoTargetGroup",
        (Platform,),
        {
            "platform_list": platform_list,
            "platform_name": platform_list[0].platform_name,
            "name": name,
            "categories": categories,
            "site": site,
            "is_common": platform_list[0].is_common,
            "enabled": True,
            "has_target": False,
            "enable_tag": False,
            "__init__": __init__,
            # previously defined above but never registered on the class
            "__str__": __str__,
            "get_target_name": get_target_name,
            "fetch_new_post": fetch_new_post,
        },
        abstract=True,
    )