diff --git a/src/plugins/nonebot_hk_reporter/platform/platform.py b/src/plugins/nonebot_hk_reporter/platform/platform.py index e1b30e0..104568b 100644 --- a/src/plugins/nonebot_hk_reporter/platform/platform.py +++ b/src/plugins/nonebot_hk_reporter/platform/platform.py @@ -6,6 +6,7 @@ import httpx from nonebot import logger from ..config import Config +from ..plugin_config import plugin_config from ..post import Post from ..types import Category, RawPost, Tag, Target, User from ..utils import Singleton @@ -22,166 +23,16 @@ class PlatformProto(metaclass=Singleton): has_target: bool platform_name: str enable_tag: bool - async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]: - ... - @staticmethod - async def get_account_name(target: Target) -> Optional[str]: - ... - -class Platform(PlatformProto): - "platform with target(account), like weibo, bilibili" - - categories: dict[Category, str] - has_target: bool = True - platform_name: str - enable_tag: bool - - def __init__(self): - self.exists_posts = defaultdict(set) - self.inited = dict() - self.reverse_category = {} - self.cache: dict[Any, Post] = {} - for key, val in self.categories.items(): - self.reverse_category[val] = key - - @staticmethod - async def get_account_name(target: Target) -> Optional[str]: - "Given a tareget, return the username(name) of the target" - raise NotImplementedError() - - async def get_sub_list(self, target: Target) -> list[RawPost]: - "Get post list of the given target" - raise NotImplementedError() - - def get_id(self, post: RawPost) -> Any: - "Get post id of given RawPost" - raise NotImplementedError() - - def get_date(self, post: RawPost) -> Optional[int]: - "Get post timestamp and return, return None if can't get the time" - raise NotImplementedError() - - def get_category(self, post: RawPost) -> Optional[Category]: - "Return category of given Rawpost" - raise NotImplementedError() - - def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]: - "Return Tag list of given RawPost" - raise NotImplementedError() - - async def parse(self, raw_post: RawPost) -> Post: - "parse RawPost into post" - raise NotImplementedError() - - def filter_platform_custom(self, post: RawPost) -> bool: - raise NotImplementedError() - - async def _parse_with_cache(self, post: RawPost) -> Post: - post_id = self.get_id(post) - if post_id not in self.cache: - self.cache[post_id] = await self.parse(post) - return self.cache[post_id] - - async def filter_common(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]: - if not self.inited.get(target, False): - # target not init - for raw_post in raw_post_list: - post_id = self.get_id(raw_post) - self.exists_posts[target].add(post_id) - logger.info('init {}-{} with {}'.format(self.platform_name, target, self.exists_posts[target])) - self.inited[target] = True - return [] - res: list[RawPost] = [] - for raw_post in raw_post_list: - post_id = self.get_id(raw_post) - if post_id in self.exists_posts[target]: - continue - if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60: - continue - try: - if not self.filter_platform_custom(raw_post): - continue - except NotImplementedError: - pass - try: - self.get_category(raw_post) - except CategoryNotSupport: - continue - except NotImplementedError: - pass - res.append(raw_post) - self.exists_posts[target].add(post_id) - return res - - async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]: - res: list[RawPost] = [] - for raw_post in raw_post_list: - if self.categories: - cat = self.get_category(raw_post) - if cats and cat not in cats: - continue - if self.enable_tag and tags: - flag = False - post_tags = self.get_tags(raw_post) - for tag in post_tags: - if tag in tags: - flag = True - break - if not flag: - continue - res.append(raw_post) - return res + cache: dict[Any, Post] async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]: - try: - config = Config() - post_list = await self.get_sub_list(target) - new_posts = await self.filter_common(target, post_list) - res: list[tuple[User, list[Post]]] = [] - if not new_posts: - return [] - else: - for post in new_posts: - logger.info('fetch new post from {} {}: {}'.format(self.platform_name, target, self.get_id(post))) - for user in users: - required_tags = config.get_sub_tags(self.platform_name, target, user.user_type, user.user) if self.enable_tag else [] - cats = config.get_sub_category(self.platform_name, target, user.user_type, user.user) - user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags) - user_post: list[Post] = [] - for raw_post in user_raw_post: - user_post.append(await self._parse_with_cache(raw_post)) - res.append((user, user_post)) - self.cache = {} - return res - except httpx.RequestError as err: - logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) - return [] - - -class PlatformNoTarget(PlatformProto): - - categories: dict[Category, str] - has_target = False - platform_name: str - enable_tag: bool - - def __init__(self): - self.exists_posts = set() - self.inited = False - self.reverse_category = {} - self.cache: dict[Any, Post] = {} - for key, val in self.categories.items(): - self.reverse_category[val] = key + ... @staticmethod async def get_account_name(target: Target) -> Optional[str]: "return the username(name) of the target" raise NotImplementedError() - async def get_sub_list(self) -> list[RawPost]: - "Get post list of the given target" - raise NotImplementedError() - def get_id(self, post: RawPost) -> Any: "Get post id of given RawPost" raise NotImplementedError() @@ -219,21 +70,14 @@ class PlatformNoTarget(PlatformProto): retry_times -= 1 return self.cache[post_id] - async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]: - if not self.inited: - # target not init - for raw_post in raw_post_list: - post_id = self.get_id(raw_post) - self.exists_posts.add(post_id) - logger.info('init {} with {}'.format(self.platform_name, self.exists_posts)) - self.inited = True - return [] - res: list[RawPost] = [] + def _do_filter_common(self, raw_post_list: list[RawPost], exists_posts_set: set) -> list[RawPost]: + res = [] for raw_post in raw_post_list: post_id = self.get_id(raw_post) - if post_id in self.exists_posts: + if post_id in exists_posts_set: continue - if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60: + if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60 and \ + plugin_config.hk_reporter_init_filter: continue try: if not self.filter_platform_custom(raw_post): @@ -247,7 +91,7 @@ class PlatformNoTarget(PlatformProto): except NotImplementedError: pass res.append(raw_post) - self.exists_posts.add(post_id) + exists_posts_set.add(post_id) return res async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]: @@ -269,6 +113,94 @@ class PlatformNoTarget(PlatformProto): res.append(raw_post) return res + +class Platform(PlatformProto): + "platform with target(account), like weibo, bilibili" + + categories: dict[Category, str] + has_target: bool = True + platform_name: str + enable_tag: bool + + def __init__(self): + self.exists_posts = defaultdict(set) + self.inited = dict() + self.reverse_category = {} + self.cache: dict[Any, Post] = {} + for key, val in self.categories.items(): + self.reverse_category[val] = key + + async def get_sub_list(self, target: Target) -> list[RawPost]: + "Get post list of the given target" + raise NotImplementedError() + + async def filter_common(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]: + if not self.inited.get(target, False) and plugin_config.hk_reporter_init_filter: + # target not init + for raw_post in raw_post_list: + post_id = self.get_id(raw_post) + self.exists_posts[target].add(post_id) + logger.info('init {}-{} with {}'.format(self.platform_name, target, self.exists_posts[target])) + self.inited[target] = True + return [] + return self._do_filter_common(raw_post_list, self.exists_posts[target]) + + async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]: + try: + config = Config() + post_list = await self.get_sub_list(target) + new_posts = await self.filter_common(target, post_list) + res: list[tuple[User, list[Post]]] = [] + if not new_posts: + return [] + else: + for post in new_posts: + logger.info('fetch new post from {} {}: {}'.format(self.platform_name, target, self.get_id(post))) + for user in users: + required_tags = config.get_sub_tags(self.platform_name, target, user.user_type, user.user) if self.enable_tag else [] + cats = config.get_sub_category(self.platform_name, target, user.user_type, user.user) + user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags) + user_post: list[Post] = [] + for raw_post in user_raw_post: + user_post.append(await self._parse_with_cache(raw_post)) + res.append((user, user_post)) + self.cache = {} + return res + except httpx.RequestError as err: + logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url)) + return [] + + +class PlatformNoTarget(PlatformProto): + + categories: dict[Category, str] + has_target = False + platform_name: str + enable_tag: bool + + async def get_sub_list(self) -> list[RawPost]: + "Get post list of the given target" + raise NotImplementedError() + + def __init__(self): + self.exists_posts = set() + self.inited = False + self.reverse_category = {} + self.cache: dict[Any, Post] = {} + for key, val in self.categories.items(): + self.reverse_category[val] = key + + async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]: + if not self.inited and plugin_config.hk_reporter_init_filter: + # target not init + for raw_post in raw_post_list: + post_id = self.get_id(raw_post) + self.exists_posts.add(post_id) + logger.info('init {} with {}'.format(self.platform_name, self.exists_posts)) + self.inited = True + return [] + return self._do_filter_common(raw_post_list, self.exists_posts) + async def fetch_new_post(self, _: Target, users: list[User]) -> list[tuple[User, list[Post]]]: try: config = Config() diff --git a/src/plugins/nonebot_hk_reporter/plugin_config.py b/src/plugins/nonebot_hk_reporter/plugin_config.py index 83554ff..ca9749c 100644 --- a/src/plugins/nonebot_hk_reporter/plugin_config.py +++ b/src/plugins/nonebot_hk_reporter/plugin_config.py @@ -6,6 +6,7 @@ class PlugConfig(BaseSettings): hk_reporter_config_path: str = "" hk_reporter_use_pic: bool = False hk_reporter_use_local: bool = False + hk_reporter_init_filter: bool = True class Config: extra = 'ignore'