import ssl
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Collection, Literal, Optional

import httpx
from nonebot.log import logger

from ..plugin_config import plugin_config
from ..post import Post
from ..types import Category, RawPost, Tag, Target, User, UserSubInfo


class CategoryNotSupport(Exception):
    """Raised in ``get_category`` when the post category is not supported."""


def _log_network_error(err: httpx.RequestError) -> None:
    """Log a warning for a failed HTTP request without ever raising itself.

    ``err.request`` raises ``RuntimeError`` when the exception was created
    without an attached request, so the URL lookup is guarded; otherwise the
    logging call inside an ``except`` block could crash the fetch.
    """
    try:
        url: Any = err.request.url
    except RuntimeError:
        url = "unknown"
    logger.warning("network connection error: {}, url: {}".format(type(err), url))


class RegistryMeta(type):
    """Metaclass that collects concrete subclasses into ``cls.registry``.

    Class keyword arguments:
        base: when truthy, the class being created starts a new, empty
            ``registry`` list.
        abstract: when truthy, the class is NOT appended to the registry.
    Any class created with neither keyword is appended to the registry it
    inherits.
    """

    def __new__(cls, name, bases, namespace, **kwargs):
        # Class keywords (base=..., abstract=...) are consumed in __init__;
        # they are not forwarded to type.__new__.
        return super().__new__(cls, name, bases, namespace)

    def __init__(cls, name, bases, namespace, **kwargs):
        if kwargs.get("base"):
            # this is the base class
            cls.registry = []
        elif not kwargs.get("abstract"):
            # this is the subclass
            cls.registry.append(cls)
        super().__init__(name, bases, namespace, **kwargs)


# NOTE(review): this mixes in ``ABC`` rather than ``ABCMeta``. Presumably that
# means @abstractmethod is not enforced at instantiation time -- confirm before
# changing, since NoTargetGroup below does not implement every abstract method.
class RegistryABCMeta(RegistryMeta, ABC):
    ...


class Platform(metaclass=RegistryABCMeta, base=True):
    """Base class of every platform; owns the subclass registry.

    Subclasses provide the class attributes annotated below and implement the
    abstract fetch/parse hooks. This class supplies per-target storage and the
    per-user category/tag filtering used when dispatching posts.
    """

    schedule_type: Literal["date", "interval", "cron"]
    schedule_kw: dict
    is_common: bool
    enabled: bool
    name: str
    has_target: bool
    # Mapping of category -> label; an empty mapping disables category filtering.
    categories: dict[Category, str]
    # When false, tag filtering is skipped entirely.
    enable_tag: bool
    # Per-target data, see get_stored_data / set_stored_data.
    store: dict[Target, Any]
    platform_name: str
    parse_target_promot: Optional[str] = None

    @abstractmethod
    async def get_target_name(self, target: Target) -> Optional[str]:
        ...

    @abstractmethod
    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        ...

    @abstractmethod
    async def parse(self, raw_post: RawPost) -> Post:
        ...

    async def do_parse(self, raw_post: RawPost) -> Post:
        """Parse entry point actually called by the dispatcher; wraps ``parse``."""
        return await self.parse(raw_post)

    def __init__(self):
        super().__init__()
        # Inverse of ``categories``: label -> category.
        self.reverse_category = {}
        for key, val in self.categories.items():
            self.reverse_category[val] = key
        self.store = dict()

    class ParseTargetException(Exception):
        pass

    async def parse_target(self, target_string: str) -> Target:
        """Convert a user-supplied string to a ``Target`` (default: wrap as-is)."""
        return Target(target_string)

    @abstractmethod
    def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]:
        """Return the tags of the given RawPost."""

    def get_stored_data(self, target: Target) -> Any:
        """Return the data stored for ``target``, or ``None`` if there is none."""
        return self.store.get(target)

    def set_stored_data(self, target: Target, data: Any):
        """Store ``data`` for ``target``, replacing any previous value."""
        self.store[target] = data

    async def filter_user_custom(
        self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]
    ) -> list[RawPost]:
        """Keep only the raw posts matching one user's category/tag subscription.

        A post is dropped when the platform has categories, ``cats`` is
        non-empty and the post's category is not in it; or when tags are
        enabled, ``tags`` is non-empty and the post carries none of them.
        """
        res: list[RawPost] = []
        for raw_post in raw_post_list:
            if self.categories:
                # get_category is called even when ``cats`` is empty so that
                # any exception it raises still propagates.
                cat = self.get_category(raw_post)
                if cats and cat not in cats:
                    continue
            if self.enable_tag and tags:
                post_tags = self.get_tags(raw_post) or []
                if not any(tag in tags for tag in post_tags):
                    continue
            res.append(raw_post)
        return res

    async def dispatch_user_post(
        self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Filter ``new_posts`` per user and parse what each user should get.

        Each entry of ``users`` is unpacked as ``(user, category_getter,
        tag_getter)``; both getters are called with ``target``. Every user gets
        an entry in the result, even when their post list is empty.
        """
        res: list[tuple[User, list[Post]]] = []
        for user, category_getter, tag_getter in users:
            required_tags = tag_getter(target) if self.enable_tag else []
            cats = category_getter(target)
            user_raw_post = await self.filter_user_custom(
                new_posts, cats, required_tags
            )
            user_post: list[Post] = [
                await self.do_parse(raw_post) for raw_post in user_raw_post
            ]
            res.append((user, user_post))
        return res

    @abstractmethod
    def get_category(self, post: RawPost) -> Optional[Category]:
        """Return the category of the given RawPost."""
        raise NotImplementedError()


class MessageProcess(Platform, abstract=True):
    """General message processing: fetch, parse and filter, with a parse cache."""

    def __init__(self):
        super().__init__()
        # post id -> parsed Post, so a post shared by several users is parsed once.
        self.parse_cache: dict[Any, Post] = dict()

    @abstractmethod
    def get_id(self, post: RawPost) -> Any:
        """Return the id of the given RawPost."""

    async def do_parse(self, raw_post: RawPost) -> Post:
        """Parse ``raw_post`` with caching; retry ``parse`` up to three times.

        Raises whatever ``parse`` raised on the third consecutive failure.
        """
        post_id = self.get_id(raw_post)
        if post_id not in self.parse_cache:
            retry_times = 3
            while retry_times:
                try:
                    self.parse_cache[post_id] = await self.parse(raw_post)
                    break
                except Exception:
                    retry_times -= 1
                    if not retry_times:
                        raise
        return self.parse_cache[post_id]

    @abstractmethod
    async def get_sub_list(self, target: Target) -> list[RawPost]:
        """Return the post list of the given target."""

    @abstractmethod
    def get_date(self, post: RawPost) -> Optional[int]:
        """Return the post timestamp, or None if the time can't be determined."""

    async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
        """Drop posts that are too old or whose category is unsupported.

        A post is skipped when it has a truthy timestamp more than two hours in
        the past and ``plugin_config.bison_init_filter`` is set, or when
        ``get_category`` raises ``CategoryNotSupport``. A ``NotImplementedError``
        from ``get_category`` is ignored and the post is kept.
        """
        res = []
        for raw_post in raw_post_list:
            if (
                (post_time := self.get_date(raw_post))
                and time.time() - post_time > 2 * 60 * 60
                and plugin_config.bison_init_filter
            ):
                continue
            try:
                self.get_category(raw_post)
            except CategoryNotSupport:
                continue
            except NotImplementedError:
                pass
            res.append(raw_post)
        return res


class NewMessage(MessageProcess, abstract=True):
    """Fetch a list of messages, keep the new ones and dispatch them to users."""

    @dataclass
    class MessageStorage:
        inited: bool
        exists_posts: set[Any]

    async def filter_common_with_diff(
        self, target: Target, raw_post_list: list[RawPost]
    ) -> list[RawPost]:
        """Return the posts of ``raw_post_list`` not seen before for ``target``.

        On the first call for a target, when ``plugin_config.bison_init_filter``
        is set, every post id is only recorded and nothing is returned.
        Otherwise unseen posts are returned and their ids recorded.
        """
        filtered_post = await self.filter_common(raw_post_list)
        store = self.get_stored_data(target) or self.MessageStorage(False, set())
        res = []
        if not store.inited and plugin_config.bison_init_filter:
            # target not init
            for raw_post in filtered_post:
                post_id = self.get_id(raw_post)
                store.exists_posts.add(post_id)
            logger.info(
                "init {}-{} with {}".format(
                    self.platform_name, target, store.exists_posts
                )
            )
            store.inited = True
        else:
            for raw_post in filtered_post:
                post_id = self.get_id(raw_post)
                if post_id in store.exists_posts:
                    continue
                res.append(raw_post)
                store.exists_posts.add(post_id)
        self.set_stored_data(target, store)
        return res

    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Fetch, diff and dispatch new posts of ``target``.

        Returns an empty list when there is nothing new or when a network/SSL
        error occurs (the error is logged as a warning, not raised).
        """
        try:
            post_list = await self.get_sub_list(target)
            new_posts = await self.filter_common_with_diff(target, post_list)
            if not new_posts:
                return []
            for post in new_posts:
                logger.info(
                    "fetch new post from {} {}: {}".format(
                        self.platform_name,
                        target if self.has_target else "-",
                        self.get_id(post),
                    )
                )
            res = await self.dispatch_user_post(target, new_posts, users)
            # Parsed posts are only shared between users of one dispatch round.
            self.parse_cache = {}
            return res
        except httpx.RequestError as err:
            _log_network_error(err)
            return []
        except ssl.SSLError as err:
            logger.warning(f"ssl error: {err}")
            return []


class StatusChange(Platform, abstract=True):
    """Watch a status, and fire a post when the status changes."""

    @abstractmethod
    async def get_status(self, target: Target) -> Any:
        ...

    @abstractmethod
    def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]:
        ...

    @abstractmethod
    async def parse(self, raw_post: RawPost) -> Post:
        ...

    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Fetch the current status and dispatch the posts produced by a change.

        The new status always replaces the stored one. Returns an empty list
        when no status was stored yet, when nothing changed, or when a
        network/SSL error occurs (logged as a warning, not raised).
        """
        try:
            new_status = await self.get_status(target)
            res = []
            old_status = self.get_stored_data(target)
            # Compare against None explicitly: a stored falsy status
            # (False, 0, empty container) is still a real previous status.
            if old_status is not None:
                diff = self.compare_status(target, old_status, new_status)
                if diff:
                    logger.info(
                        "status changes {} {}: {} -> {}".format(
                            self.platform_name,
                            target if self.has_target else "-",
                            old_status,
                            new_status,
                        )
                    )
                    res = await self.dispatch_user_post(target, diff, users)
            self.set_stored_data(target, new_status)
            return res
        except httpx.RequestError as err:
            _log_network_error(err)
            return []
        except ssl.SSLError as err:
            logger.warning(f"ssl error: {err}")
            return []


class SimplePost(MessageProcess, abstract=True):
    """Fetch a list of messages and dispatch all of them to users."""

    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Fetch the posts of ``target`` and dispatch them without diffing.

        Returns an empty list when there are no posts or when a network/SSL
        error occurs (logged as a warning, not raised).
        """
        try:
            new_posts = await self.get_sub_list(target)
            if not new_posts:
                return []
            for post in new_posts:
                logger.info(
                    "fetch new post from {} {}: {}".format(
                        self.platform_name,
                        target if self.has_target else "-",
                        self.get_id(post),
                    )
                )
            res = await self.dispatch_user_post(target, new_posts, users)
            # Parsed posts are only shared between users of one dispatch round.
            self.parse_cache = {}
            return res
        except httpx.RequestError as err:
            _log_network_error(err)
            return []
        except ssl.SSLError as err:
            logger.warning(f"ssl error: {err}")
            return []


class NoTargetGroup(Platform, abstract=True):
    """Present several target-less platforms with the same name as one platform."""

    enable_tag = False
    DUMMY_STR = "_DUMMY"
    enabled = True
    has_target = False

    def __init__(self, platform_list: list[Platform]):
        """Validate and merge ``platform_list``.

        Raises:
            RuntimeError: if a platform has a target, the platforms' names
                differ, their category keys overlap, or their scheduler
                settings differ from the first platform's.
        """
        self.platform_list = platform_list
        first = platform_list[0]
        # Must be set before validation: the error messages below format
        # self.platform_name, which is only an annotation on Platform.
        # Falls back to the display name when the first platform has none.
        self.platform_name = getattr(first, "platform_name", first.name)
        name = self.DUMMY_STR
        self.categories = {}
        categories_keys = set()
        self.schedule_type = first.schedule_type
        self.schedule_kw = first.schedule_kw
        for platform in platform_list:
            if platform.has_target:
                raise RuntimeError(
                    "Platform {} should have no target".format(platform.name)
                )
            if name == self.DUMMY_STR:
                name = platform.name
            elif name != platform.name:
                raise RuntimeError(
                    "Platform name for {} not fit".format(self.platform_name)
                )
            platform_category_key_set = set(platform.categories.keys())
            if platform_category_key_set & categories_keys:
                raise RuntimeError(
                    "Platform categories for {} duplicate".format(self.platform_name)
                )
            categories_keys |= platform_category_key_set
            self.categories.update(platform.categories)
            if (
                platform.schedule_kw != self.schedule_kw
                or platform.schedule_type != self.schedule_type
            ):
                raise RuntimeError(
                    "Platform scheduler for {} not fit".format(self.platform_name)
                )
        self.name = name
        self.is_common = first.is_common
        # Called last so Platform.__init__ sees the merged ``categories``.
        super().__init__()

    def __str__(self):
        return "[" + " ".join(platform.name for platform in self.platform_list) + "]"

    async def get_target_name(self, _):
        """Delegate to the first wrapped platform."""
        return await self.platform_list[0].get_target_name(_)

    async def fetch_new_post(self, target, users):
        """Fetch from every wrapped platform in turn and merge posts per user.

        Returns a list of ``[user, posts]`` pairs (lists, not tuples).
        """
        res = defaultdict(list)
        for platform in self.platform_list:
            platform_res = await platform.fetch_new_post(target=target, users=users)
            for user, posts in platform_res:
                res[user].extend(posts)
        return [[user, posts] for user, posts in res.items()]