import ssl
import json
import time
import typing
from dataclasses import dataclass
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, TypeVar, ParamSpec
from collections.abc import Callable, Awaitable, Collection

import httpx
from httpx import AsyncClient
from nonebot.log import logger
from nonebot_plugin_saa import PlatformTarget

from ..post import Post
from ..utils import Site, ProcessContext
from ..plugin_config import plugin_config
from ..types import Tag, Target, RawPost, SubUnit, Category


class CategoryNotSupport(Exception):
    """raise in get_category, when you know the category of the post
    but don't want to support it or don't support its parsing yet
    """


class CategoryNotRecognize(Exception):
    """raise in get_category, when you don't know the category of post"""


class RegistryMeta(type):
    """Metaclass that collects concrete classes into a shared ``registry`` list.

    Recognised class-definition keywords:
        base: truthy on the root class; an empty ``registry`` list is created on it.
        abstract: truthy on classes that must not be appended to ``registry``.

    Any class created with neither keyword is appended to the ``registry``
    it inherits.
    """

    def __new__(cls, name, bases, namespace, **kwargs):
        # The class keywords are only interpreted in __init__; they are not
        # forwarded to type.__new__.
        return super().__new__(cls, name, bases, namespace)

    def __init__(cls, name, bases, namespace, **kwargs):
        if kwargs.get("base"):
            # this is the base class
            cls.registry = []
        elif not kwargs.get("abstract"):
            # this is the subclass
            cls.registry.append(cls)

        super().__init__(name, bases, namespace, **kwargs)


P = ParamSpec("P")
R = TypeVar("R")


async def catch_network_error(func: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs) -> R | None:
    """Await ``func(*args, **kwargs)`` and absorb network-level failures.

    Returns:
        The awaited result, or ``None`` when ``httpx.RequestError`` or
        ``ssl.SSLError`` was raised. Those two are logged as warnings only if
        ``plugin_config.bison_show_network_warning`` is truthy.

    Raises:
        json.JSONDecodeError: logged (with the document being parsed) and re-raised.
    """
    try:
        return await func(*args, **kwargs)
    except httpx.RequestError as err:
        if plugin_config.bison_show_network_warning:
            logger.warning(f"network connection error: {type(err)}, url: {err.request.url}")
        return None
    except ssl.SSLError as err:
        if plugin_config.bison_show_network_warning:
            logger.warning(f"ssl error: {err}")
        return None
    except json.JSONDecodeError as err:
        logger.warning(f"json error, parsing: {err.doc}")
        raise


class PlatformMeta(RegistryMeta):
    """Registry metaclass that also gives every class its own ``store`` and ``reverse_category``."""

    categories: dict[Category, str]
    store: dict[Target, Any]

    def __init__(cls, name, bases, namespace, **kwargs):
        # reverse_category maps category display string -> category key.
        categories = getattr(cls, "categories", None)
        cls.reverse_category = {val: key for key, val in categories.items()} if categories else {}
        # A fresh dict per class, so stored data is not shared with parent classes.
        cls.store = {}
        super().__init__(name, bases, namespace, **kwargs)


class PlatformABCMeta(PlatformMeta, ABC):
    """Combination of ``PlatformMeta`` and ``ABC`` used as the metaclass of ``Platform``."""


class Platform(metaclass=PlatformABCMeta, base=True):
    """Root class of all platforms; declares the fetch / parse / filter interface.

    Concrete subclasses (those defined without ``abstract=True``) are appended
    to ``Platform.registry`` by the metaclass.
    """

    site: type[Site]
    ctx: ProcessContext
    is_common: bool
    enabled: bool
    name: str
    has_target: bool
    categories: dict[Category, str]
    enable_tag: bool
    platform_name: str
    parse_target_promot: str | None = None
    registry: list[type["Platform"]]
    reverse_category: dict[str, Category]
    use_batch: bool = False
    # TODO: restrict the theme names that may be used
    default_theme: str = "basic"

    @classmethod
    @abstractmethod
    async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: ...

    @abstractmethod
    async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]: ...

    async def do_fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
        """Call ``fetch_new_post`` through ``catch_network_error``; a falsy result becomes ``[]``."""
        return await catch_network_error(self.fetch_new_post, sub_unit) or []

    @abstractmethod
    async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]: ...

    async def do_batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
        """Call ``batch_fetch_new_post`` through ``catch_network_error``; a falsy result becomes ``[]``."""
        return await catch_network_error(self.batch_fetch_new_post, sub_units) or []

    @abstractmethod
    async def parse(self, raw_post: RawPost) -> Post: ...

    async def do_parse(self, raw_post: RawPost) -> Post:
        """Parse entry point actually called by the dispatch code; delegates to ``parse``."""
        return await self.parse(raw_post)

    def __init__(self, context: ProcessContext):
        super().__init__()
        self.ctx = context

    class ParseTargetException(Exception):
        """Exception carrying an optional ``prompt`` in addition to the usual args."""

        def __init__(self, *args: object, prompt: str | None = None) -> None:
            super().__init__(*args)

            # Hint message about the expected user input.
            self.prompt = prompt

    @classmethod
    async def parse_target(cls, target_string: str) -> Target:
        """Convert a raw string into a ``Target``; the default just wraps the string."""
        return Target(target_string)

    @abstractmethod
    def get_tags(self, raw_post: RawPost) -> Collection[Tag] | None:
        "Return Tag list of given RawPost"

    @classmethod
    def get_stored_data(cls, target: Target) -> Any:
        """Return the data stored for ``target`` on this class, or ``None`` if absent."""
        return cls.store.get(target)

    @classmethod
    def set_stored_data(cls, target: Target, data: Any):
        """Store ``data`` for ``target`` on this class, replacing any previous value."""
        cls.store[target] = data

    def tag_separator(self, stored_tags: list[Tag]) -> tuple[list[Tag], list[Tag]]:
        """Split stored tags into ``(subscribed_tags, banned_tags)``.

        A tag starting with ``~`` is a banned tag and is returned without its
        leading tildes; every other tag is a subscribed tag.
        """
        subscribed_tags = []
        banned_tags = []
        for tag in stored_tags:
            if tag.startswith("~"):
                # NOTE(review): lstrip removes every leading "~", not only the first one.
                banned_tags.append(tag.lstrip("~"))
            else:
                subscribed_tags.append(tag)
        return subscribed_tags, banned_tags

    def is_banned_post(
        self,
        post_tags: Collection[Tag],
        subscribed_tags: list[Tag],
        banned_tags: list[Tag],
    ) -> bool:
        """Decide whether a post must be dropped because of its tags.

        Rules, in priority order:
        1. Any banned tag on the post -> True.
        2. Subscribed tags are given -> False if the post carries any of them,
           True otherwise.
        3. No subscribed tags -> False.
        """
        tags = post_tags or []
        # Banned tags take precedence over everything else.
        if banned_tags and any(tag in banned_tags for tag in tags):
            return True
        # Only after the banned check: the post passes if it carries any subscribed tag.
        if subscribed_tags:
            return not any(tag in subscribed_tags for tag in tags)
        return False

    async def filter_user_custom(
        self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]
    ) -> list[RawPost]:
        """Keep only the raw posts matching one user's category and tag settings.

        Args:
            raw_post_list: candidate raw posts.
            cats: categories the user subscribed to; empty means no category filtering.
            tags: the user's stored tags (see ``tag_separator``); empty means no tag filtering.
        """
        # The split only depends on ``tags``, so do it once instead of once per post.
        separated_tags = self.tag_separator(tags) if self.enable_tag and tags else None

        res: list[RawPost] = []
        for raw_post in raw_post_list:
            if self.categories:
                cat = self.get_category(raw_post)
                if cats and cat not in cats:
                    continue
            if separated_tags is not None:
                raw_post_tags = self.get_tags(raw_post)
                # get_tags may return None; in that case no tag filtering is applied.
                if isinstance(raw_post_tags, Collection) and self.is_banned_post(raw_post_tags, *separated_tags):
                    continue
            res.append(raw_post)
        return res

    async def dispatch_user_post(
        self, new_posts: list[RawPost], sub_unit: SubUnit
    ) -> list[tuple[PlatformTarget, list[Post]]]:
        """Filter and parse ``new_posts`` separately for every user in ``sub_unit.user_sub_infos``.

        Returns:
            One ``(user, parsed_posts)`` tuple per user, including users whose
            post list ends up empty.
        """
        res: list[tuple[PlatformTarget, list[Post]]] = []
        for user, cats, required_tags in sub_unit.user_sub_infos:
            user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
            user_post: list[Post] = [await self.do_parse(raw_post) for raw_post in user_raw_post]
            res.append((user, user_post))
        return res

    @abstractmethod
    def get_category(self, post: RawPost) -> Category | None:
        "Return category of given Rawpost"
        raise NotImplementedError()


class MessageProcess(Platform, abstract=True):
    "General message process fetch, parse, filter progress"

    def __init__(self, ctx: ProcessContext):
        super().__init__(ctx)
        # Parsed posts keyed by the value returned from get_id.
        self.parse_cache: dict[Any, Post] = {}

    @abstractmethod
    def get_id(self, post: RawPost) -> Any:
        "Get post id of given RawPost"

    async def do_parse(self, raw_post: RawPost) -> Post:
        """Parse ``raw_post`` with caching and up to three attempts.

        The result is cached in ``parse_cache`` under the post id. If all three
        attempts raise, the exception of the last attempt propagates.
        """
        post_id = self.get_id(raw_post)
        if post_id not in self.parse_cache:
            retry_times = 3
            while retry_times:
                try:
                    self.parse_cache[post_id] = await self.parse(raw_post)
                    break
                except Exception:
                    retry_times -= 1
                    if not retry_times:
                        raise
        return self.parse_cache[post_id]

    @abstractmethod
    async def get_sub_list(self, target: Target) -> list[RawPost]:
        "Get post list of the given target"
        raise NotImplementedError()

    @abstractmethod
    async def batch_get_sub_list(self, targets: list[Target]) -> list[list[RawPost]]:
        "Get post list of the given targets"
        raise NotImplementedError()

    @abstractmethod
    def get_date(self, post: RawPost) -> int | None:
        "Get post timestamp and return, return None if can't get the time"

    async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
        """Drop posts that are too old or whose category is unsupported or unknown.

        A post is skipped when:
        - ``get_date`` returns a truthy timestamp more than two hours in the
          past and ``plugin_config.bison_init_filter`` is truthy;
        - ``get_category`` raises ``CategoryNotSupport`` (logged at info level);
        - ``get_category`` raises ``CategoryNotRecognize`` (logged as a warning,
          together with every message from ``self.ctx.gen_req_records()``).

        ``NotImplementedError`` from ``get_category`` is ignored and the post is kept.
        """
        res = []
        for raw_post in raw_post_list:
            if (
                (post_time := self.get_date(raw_post))
                and time.time() - post_time > 2 * 60 * 60
                and plugin_config.bison_init_filter
            ):
                continue
            try:
                self.get_category(raw_post)
            except CategoryNotSupport as e:
                logger.info("未支持解析的推文类别:" + repr(e) + ",忽略")
                continue
            except CategoryNotRecognize as e:
                logger.warning("未知推文类别:" + repr(e))
                msgs = self.ctx.gen_req_records()
                for m in msgs:
                    logger.warning(m)
                continue
            except NotImplementedError:
                pass
            res.append(raw_post)
        return res


class NewMessage(MessageProcess, abstract=True):
    "Fetch a list of messages, filter the new messages, dispatch it to different users"

    @dataclass
    class MessageStorage:
        """Per-target state: whether the target was initialised and the post ids already seen."""

        inited: bool
        exists_posts: set[Any]

    async def filter_common_with_diff(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]:
        """Return the posts of ``raw_post_list`` that have not been seen for ``target``.

        On the first call for a target, when ``plugin_config.bison_init_filter``
        is truthy, all post ids are only recorded and nothing is returned.
        Otherwise every post whose id is not yet recorded is returned and recorded.
        """
        filtered_post = await self.filter_common(raw_post_list)
        store = self.get_stored_data(target) or self.MessageStorage(False, set())
        res = []
        if not store.inited and plugin_config.bison_init_filter:
            # target not init
            for raw_post in filtered_post:
                post_id = self.get_id(raw_post)
                store.exists_posts.add(post_id)
            logger.info(f"init {self.platform_name}-{target} with {store.exists_posts}")
            store.inited = True
        else:
            for raw_post in filtered_post:
                post_id = self.get_id(raw_post)
                if post_id in store.exists_posts:
                    continue
                res.append(raw_post)
                store.exists_posts.add(post_id)
        self.set_stored_data(target, store)
        logger.trace(f"本次抓取 {len(raw_post_list)} 条,过滤后 {len(filtered_post)} 条,新消息 {len(res)} 条")
        return res

    async def _handle_new_post(
        self,
        post_list: list[RawPost],
        sub_unit: SubUnit,
    ) -> list[tuple[PlatformTarget, list[Post]]]:
        """Diff ``post_list`` against the stored state, log the new posts and dispatch them."""
        new_posts = await self.filter_common_with_diff(sub_unit.sub_target, post_list)
        if not new_posts:
            return []

        target_label = sub_unit.sub_target if self.has_target else "-"
        for post in new_posts:
            logger.info(f"fetch new post from {self.platform_name} {target_label}: {self.get_id(post)}")
        res = await self.dispatch_user_post(new_posts, sub_unit)
        # The parse cache only lives for one dispatch round.
        self.parse_cache = {}
        return res

    async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
        """Fetch the post list of ``sub_unit.sub_target`` and dispatch what is new."""
        post_list = await self.get_sub_list(sub_unit.sub_target)
        return await self._handle_new_post(post_list, sub_unit)

    async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
        """Fetch the post lists of all ``sub_units`` in one batch call and dispatch what is new.

        Raises:
            RuntimeError: if the platform has no target.
        """
        if not self.has_target:
            raise RuntimeError("Target without target should not use batch api")  # pragma: no cover
        # x[0] is presumably the sub-unit's target — confirm against SubUnit's definition.
        posts_set = await self.batch_get_sub_list([x[0] for x in sub_units])
        res = []
        for sub_unit, posts in zip(sub_units, posts_set):
            res.extend(await self._handle_new_post(posts, sub_unit))
        return res


class StatusChange(Platform, abstract=True):
    "Watch a status, and fire a post when status changes"

    class FetchError(RuntimeError):
        """Error type that ``fetch_new_post`` logs before re-raising."""

    @abstractmethod
    async def get_status(self, target: Target) -> Any: ...

    @abstractmethod
    async def batch_get_status(self, targets: list[Target]) -> list[Any]: ...

    @abstractmethod
    def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]: ...

    @abstractmethod
    async def parse(self, raw_post: RawPost) -> Post: ...

    async def _handle_status_change(
        self, new_status: Any, sub_unit: SubUnit
    ) -> list[tuple[PlatformTarget, list[Post]]]:
        """Compare ``new_status`` with the stored status, dispatch the diff, then store ``new_status``.

        The comparison is skipped whenever the stored status is falsy, which
        includes the case where nothing has been stored yet.
        """
        res = []
        if old_status := self.get_stored_data(sub_unit.sub_target):
            diff = self.compare_status(sub_unit.sub_target, old_status, new_status)
            if diff:
                target_label = sub_unit.sub_target if self.has_target else "-"
                logger.info(f"status changes {self.platform_name} {target_label}: {old_status} -> {new_status}")
                res = await self.dispatch_user_post(diff, sub_unit)
        self.set_stored_data(sub_unit.sub_target, new_status)
        return res

    async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
        """Fetch the current status of ``sub_unit.sub_target`` and dispatch any change.

        Raises:
            FetchError: logged as a warning and re-raised.
        """
        try:
            new_status = await self.get_status(sub_unit.sub_target)
        except self.FetchError as err:
            logger.warning(f"fetching {self.name}-{sub_unit.sub_target} error: {err}")
            raise
        return await self._handle_status_change(new_status, sub_unit)

    async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]:
        """Fetch the statuses of all ``sub_units`` in one batch call and dispatch any changes.

        Raises:
            RuntimeError: if the platform has no target.
        """
        if not self.has_target:
            raise RuntimeError("Target without target should not use batch api")  # pragma: no cover
        # x[0] is presumably the sub-unit's target — confirm against SubUnit's definition.
        new_statuses = await self.batch_get_status([x[0] for x in sub_units])
        res = []
        for sub_unit, new_status in zip(sub_units, new_statuses):
            res.extend(await self._handle_status_change(new_status, sub_unit))
        return res


class SimplePost(NewMessage, abstract=True):
    "Fetch a list of messages, dispatch it to different users"

    async def _handle_new_post(
        self,
        post_list: list[RawPost],
        sub_unit: SubUnit,
    ) -> list[tuple[PlatformTarget, list[Post]]]:
        """Log and dispatch every post in ``post_list`` without diffing against stored state."""
        if not post_list:
            return []

        target_label = sub_unit.sub_target if self.has_target else "-"
        for post in post_list:
            logger.info(f"fetch new post from {self.platform_name} {target_label}: {self.get_id(post)}")
        res = await self.dispatch_user_post(post_list, sub_unit)
        # The parse cache only lives for one dispatch round.
        self.parse_cache = {}
        return res


def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
    """Build a ``Platform`` subclass that fans out to several target-less platforms.

    All platforms must have ``has_target`` falsy, share the same ``name`` and
    ``site``, and have pairwise disjoint category keys. The returned class is
    created with ``abstract=True``, so it is not added to the registry.

    Raises:
        RuntimeError: if any of the constraints above is violated.
        IndexError: if ``platform_list`` is empty.
    """
    if typing.TYPE_CHECKING:

        class NoTargetGroup(Platform, abstract=True):
            platform_list: list[type[Platform]]
            platform_obj_list: list[Platform]

    DUMMY_STR = "_DUMMY"

    platform_name = platform_list[0].platform_name
    name = DUMMY_STR
    categories_keys = set()
    categories = {}
    site = platform_list[0].site

    for platform in platform_list:
        if platform.has_target:
            raise RuntimeError(f"Platform {platform.name} should have no target")
        if name == DUMMY_STR:
            name = platform.name
        elif name != platform.name:
            raise RuntimeError(f"Platform name for {platform_name} not fit")
        platform_category_key_set = set(platform.categories)
        if platform_category_key_set & categories_keys:
            raise RuntimeError(f"Platform categories for {platform_name} duplicate")
        categories_keys |= platform_category_key_set
        categories.update(platform.categories)
        if platform.site != site:
            raise RuntimeError(f"Platform scheduler for {platform_name} not fit")

    def __init__(self: "NoTargetGroup", ctx: ProcessContext):
        Platform.__init__(self, ctx)
        # One instance of every grouped platform, all sharing the same context.
        self.platform_obj_list = [platform_class(ctx) for platform_class in self.platform_list]

    def __str__(self: "NoTargetGroup") -> str:
        return "[" + " ".join(x.name for x in self.platform_list) + "]"

    @classmethod
    async def get_target_name(cls, client: AsyncClient, target: Target):
        return await platform_list[0].get_target_name(client, target)

    async def fetch_new_post(self: "NoTargetGroup", sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]:
        # Merge the results of every grouped platform per user.
        res = defaultdict(list)
        for platform in self.platform_obj_list:
            platform_res = await platform.fetch_new_post(sub_unit)
            for user, posts in platform_res:
                res[user].extend(posts)
        # (user, posts) tuples, matching the declared return type of Platform.fetch_new_post.
        return list(res.items())

    return type(
        "NoTargetGroup",
        (Platform,),
        {
            "platform_list": platform_list,
            "platform_name": platform_list[0].platform_name,
            "name": name,
            "categories": categories,
            "site": site,
            "is_common": platform_list[0].is_common,
            "enabled": True,
            "has_target": False,
            "enable_tag": False,
            "__init__": __init__,
            # Previously defined but never attached to the generated class.
            "__str__": __str__,
            "get_target_name": get_target_name,
            "fetch_new_post": fetch_new_post,
        },
        abstract=True,
    )