2022-08-04 23:04:53 +08:00

383 lines
12 KiB
Python

import json
import ssl
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Collection, Literal, Optional
import httpx
from nonebot.log import logger
from ..plugin_config import plugin_config
from ..post import Post
from ..types import Category, RawPost, Tag, Target, User, UserSubInfo
class CategoryNotSupport(Exception):
    """Raised from ``get_category`` when a post's category is not supported."""
class RegistryMeta(type):
    """Metaclass that collects concrete subclasses into a shared ``registry``.

    Recognised class keywords:

    * ``base=True``: the class owns the registry; a fresh empty ``registry``
      list is created on it.
    * ``abstract=True``: the class is not registered.
    * neither: the class is appended to the ``registry`` inherited from its
      base.
    """

    def __new__(mcs, name, bases, namespace, **kwargs):
        # The registry keywords are deliberately not forwarded: type.__new__
        # would pass them on to __init_subclass__, which rejects unknown
        # keyword arguments by default.
        return super().__new__(mcs, name, bases, namespace)

    def __init__(cls, name, bases, namespace, **kwargs):
        is_base = kwargs.get("base")
        is_abstract = kwargs.get("abstract")
        if is_base:
            cls.registry = []
        elif not is_abstract:
            cls.registry.append(cls)
        super().__init__(name, bases, namespace, **kwargs)
class RegistryABCMeta(RegistryMeta, ABC):
    """RegistryMeta variant used as the metaclass of the platform base class.

    NOTE(review): ``ABC`` is an ordinary class whose *metaclass* is
    ``ABCMeta``. Mixing ``ABC`` in here does not give this metaclass
    ``ABCMeta``'s behaviour. Class creation goes through
    ``RegistryMeta.__new__`` -> ``type.__new__``, so ``@abstractmethod`` is
    not enforced: instantiating a class that leaves abstract methods
    unimplemented does not raise. Replacing ``ABC`` with ``abc.ABCMeta``
    would enforce it, but it would also break any class currently
    instantiated without implementing every abstract method. Audit
    subclasses before changing this.
    """
class Platform(metaclass=RegistryABCMeta, base=True):
    """Base class of every platform adapter.

    Concrete subclasses are appended to ``Platform.registry`` by the
    metaclass; subclasses declared with ``abstract=True`` are skipped.
    Subclasses implement fetching (``fetch_new_post``), parsing (``parse``)
    and category/tag extraction. This class provides:

    * error handling around fetches;
    * a per-target key/value store;
    * per-user filtering and dispatch of raw posts.
    """

    scheduler_class: str
    is_common: bool
    enabled: bool
    name: str
    has_target: bool
    categories: dict[Category, str]
    enable_tag: bool
    store: dict[Target, Any]
    platform_name: str
    # (sic) misspelled attribute name kept as-is for compatibility with
    # existing subclasses and callers.
    parse_target_promot: Optional[str] = None

    @abstractmethod
    async def get_target_name(self, target: Target) -> Optional[str]:
        ...

    @abstractmethod
    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        ...

    async def do_fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Run ``fetch_new_post``, logging and swallowing transient errors.

        ``httpx.RequestError``, ``ssl.SSLError`` and
        ``json.JSONDecodeError`` are logged as warnings and turned into an
        empty result. Any other exception propagates to the caller.
        """
        try:
            return await self.fetch_new_post(target, users)
        except httpx.RequestError as err:
            # ``HTTPError.request`` raises RuntimeError when no request was
            # attached to the exception. Don't let the handler itself crash.
            try:
                url = err.request.url
            except RuntimeError:
                url = "<unknown>"
            logger.warning(
                "network connection error: {}, url: {}".format(type(err), url)
            )
            return []
        except ssl.SSLError as err:
            logger.warning(f"ssl error: {err}")
            return []
        except json.JSONDecodeError as err:
            logger.warning(f"json error, parsing: {err.doc}")
            return []

    @abstractmethod
    async def parse(self, raw_post: RawPost) -> Post:
        ...

    async def do_parse(self, raw_post: RawPost) -> Post:
        """Entry point used for parsing; subclasses may add caching/retries."""
        return await self.parse(raw_post)

    def __init__(self):
        super().__init__()
        # Inverse of ``categories``: display name -> Category.
        self.reverse_category = {val: key for key, val in self.categories.items()}
        self.store = dict()

    class ParseTargetException(Exception):
        pass

    async def parse_target(self, target_string: str) -> Target:
        """Turn user input into a Target; the default keeps it unchanged."""
        return Target(target_string)

    @abstractmethod
    def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]:
        "Return Tag list of given RawPost"

    def get_stored_data(self, target: Target) -> Any:
        """Return the value saved for ``target``, or None if nothing is stored."""
        return self.store.get(target)

    def set_stored_data(self, target: Target, data: Any):
        """Save ``data`` for ``target``, replacing any previous value."""
        self.store[target] = data

    async def filter_user_custom(
        self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]
    ) -> list[RawPost]:
        """Keep the raw posts that match one user's subscription settings.

        Categories are checked only when the platform defines any. An
        empty ``cats`` accepts every category. Tags are checked only when
        ``enable_tag`` is set and ``tags`` is non-empty. A post passes if at
        least one of its tags is in ``tags``.
        """
        res: list[RawPost] = []
        for raw_post in raw_post_list:
            if self.categories:
                cat = self.get_category(raw_post)
                if cats and cat not in cats:
                    continue
            if self.enable_tag and tags:
                post_tags = self.get_tags(raw_post)
                if not any(tag in tags for tag in post_tags or []):
                    continue
            res.append(raw_post)
        return res

    async def dispatch_user_post(
        self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Filter ``new_posts`` per user and parse the survivors.

        Each entry of ``users`` is unpacked as ``(user, cats, required_tags)``.
        The result has one ``(user, posts)`` pair per entry, even when
        ``posts`` is empty.
        """
        res: list[tuple[User, list[Post]]] = []
        for user, cats, required_tags in users:
            user_raw_post = await self.filter_user_custom(
                new_posts, cats, required_tags
            )
            user_post: list[Post] = []
            for raw_post in user_raw_post:
                user_post.append(await self.do_parse(raw_post))
            res.append((user, user_post))
        return res

    @abstractmethod
    def get_category(self, post: RawPost) -> Optional[Category]:
        "Return category of given Rawpost"
        raise NotImplementedError()
class MessageProcess(Platform, abstract=True):
    """Shared fetch -> parse -> filter pipeline for list-based platforms.

    Parsed posts are memoised in ``parse_cache``, keyed by ``get_id``.
    """

    def __init__(self):
        super().__init__()
        self.parse_cache: dict[Any, Post] = {}

    @abstractmethod
    def get_id(self, post: RawPost) -> Any:
        "Return the identifier of the given RawPost"

    async def do_parse(self, raw_post: RawPost) -> Post:
        """Parse ``raw_post``, reusing a cached result for the same post id.

        ``parse`` is attempted up to three times; if every attempt fails,
        the exception from the last attempt is re-raised.
        """
        post_id = self.get_id(raw_post)
        if post_id in self.parse_cache:
            return self.parse_cache[post_id]
        for attempt in range(3):
            try:
                self.parse_cache[post_id] = await self.parse(raw_post)
            except Exception:
                if attempt == 2:
                    raise
            else:
                break
        return self.parse_cache[post_id]

    @abstractmethod
    async def get_sub_list(self, target: Target) -> list[RawPost]:
        "Return the raw post list of the given target"

    @abstractmethod
    def get_date(self, post: RawPost) -> Optional[int]:
        "Return the post timestamp, or None when it is unknown"

    async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
        """Drop posts that are stale or belong to an unsupported category.

        A post is stale when it has a truthy timestamp more than 2 hours old
        and ``plugin_config.bison_init_filter`` is enabled. A post whose
        ``get_category`` raises ``CategoryNotSupport`` is dropped. A
        ``NotImplementedError`` from ``get_category`` is ignored and the post
        is kept.
        """
        kept = []
        for raw_post in raw_post_list:
            post_time = self.get_date(raw_post)
            is_stale = (
                post_time
                and time.time() - post_time > 2 * 60 * 60
                and plugin_config.bison_init_filter
            )
            if is_stale:
                continue
            try:
                self.get_category(raw_post)
            except CategoryNotSupport:
                continue
            except NotImplementedError:
                pass
            kept.append(raw_post)
        return kept
class NewMessage(MessageProcess, abstract=True):
    """Poll a post list, keep only unseen posts and dispatch them to users."""

    @dataclass
    class MessageStorage:
        inited: bool  # True once the initial batch of post ids was recorded
        exists_posts: set[Any]  # ids already seen for this target (never pruned here)

    async def filter_common_with_diff(
        self, target: Target, raw_post_list: list[RawPost]
    ) -> list[RawPost]:
        """Return the posts of ``raw_post_list`` not seen before for ``target``.

        When ``plugin_config.bison_init_filter`` is on, the first call for a
        target only records the current post ids and returns an empty list,
        so that existing posts are not reported as new.
        """
        filtered_post = await self.filter_common(raw_post_list)
        store = self.get_stored_data(target) or self.MessageStorage(False, set())

        if not store.inited and plugin_config.bison_init_filter:
            store.exists_posts.update(self.get_id(post) for post in filtered_post)
            logger.info(
                "init {}-{} with {}".format(
                    self.platform_name, target, store.exists_posts
                )
            )
            store.inited = True
            self.set_stored_data(target, store)
            return []

        unseen = []
        for post in filtered_post:
            post_id = self.get_id(post)
            if post_id not in store.exists_posts:
                unseen.append(post)
                store.exists_posts.add(post_id)
        self.set_stored_data(target, store)
        return unseen

    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Fetch, de-duplicate and dispatch; clears ``parse_cache`` afterwards."""
        post_list = await self.get_sub_list(target)
        new_posts = await self.filter_common_with_diff(target, post_list)
        if not new_posts:
            return []

        target_label = target if self.has_target else "-"
        for post in new_posts:
            logger.info(
                "fetch new post from {} {}: {}".format(
                    self.platform_name, target_label, self.get_id(post)
                )
            )
        dispatched = await self.dispatch_user_post(target, new_posts, users)
        self.parse_cache = {}
        return dispatched
class StatusChange(Platform, abstract=True):
    """Poll a status value per target and emit posts when it changes.

    Subclasses supply ``get_status`` and ``compare_status``; the latter
    turns an ``(old, new)`` status pair into raw posts to dispatch.
    """

    class FetchError(RuntimeError):
        pass

    @abstractmethod
    async def get_status(self, target: Target) -> Any:
        ...

    @abstractmethod
    def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]:
        ...

    @abstractmethod
    async def parse(self, raw_post: RawPost) -> Post:
        ...

    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Fetch the current status, dispatch any diff and store the new status.

        A ``FetchError`` from ``get_status`` is logged and yields an empty
        result without touching the stored status.
        """
        try:
            new_status = await self.get_status(target)
        except self.FetchError as err:
            logger.warning(f"fetching {self.name}-{target} error: {err}")
            return []

        old_status = self.get_stored_data(target)
        # NOTE(review): any falsy stored status is treated as "no baseline"
        # and skips compare_status. That covers None, but also False, 0, {}
        # and "". A transition *from* such a status is therefore never
        # reported. Confirm subclasses never use a falsy value as a real
        # status.
        diff = (
            self.compare_status(target, old_status, new_status) if old_status else []
        )

        res = []
        if diff:
            logger.info(
                "status changes {} {}: {} -> {}".format(
                    self.platform_name,
                    target if self.has_target else "-",
                    old_status,
                    new_status,
                )
            )
            res = await self.dispatch_user_post(target, diff, users)
        self.set_stored_data(target, new_status)
        return res
class SimplePost(MessageProcess, abstract=True):
    """Dispatch every post returned by ``get_sub_list``, without de-duplication."""

    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Fetch and dispatch all posts; clears ``parse_cache`` afterwards."""
        posts = await self.get_sub_list(target)
        if not posts:
            return []

        target_label = target if self.has_target else "-"
        for post in posts:
            logger.info(
                "fetch new post from {} {}: {}".format(
                    self.platform_name, target_label, self.get_id(post)
                )
            )
        dispatched = await self.dispatch_user_post(target, posts, users)
        self.parse_cache = {}
        return dispatched
class NoTargetGroup(Platform, abstract=True):
    """Combine several target-less platforms that share a name into one.

    Every member must have ``has_target`` false, the same ``name`` and the
    same ``scheduler_class``, and member category keys must not overlap. The
    group exposes the union of the members' categories and fans
    ``fetch_new_post`` out to every member.

    Raises:
        RuntimeError: from ``__init__`` when any of the constraints above is
            violated.
    """

    enable_tag = False
    DUMMY_STR = "_DUMMY"
    enabled = True
    has_target = False

    def __init__(self, platform_list: list[Platform]):
        self.platform_list = platform_list
        self.platform_name = platform_list[0].platform_name
        name = self.DUMMY_STR
        self.categories = {}
        categories_keys = set()
        self.scheduler_class = platform_list[0].scheduler_class
        for platform in platform_list:
            if platform.has_target:
                raise RuntimeError(
                    "Platform {} should have no target".format(platform.name)
                )
            if name == self.DUMMY_STR:
                name = platform.name
            elif name != platform.name:
                raise RuntimeError(
                    "Platform name for {} not fit".format(self.platform_name)
                )
            platform_category_key_set = set(platform.categories.keys())
            if platform_category_key_set & categories_keys:
                raise RuntimeError(
                    "Platform categories for {} duplicate".format(self.platform_name)
                )
            categories_keys |= platform_category_key_set
            self.categories.update(platform.categories)
            if platform.scheduler_class != self.scheduler_class:
                raise RuntimeError(
                    "Platform scheduler for {} not fit".format(self.platform_name)
                )
        self.name = name
        self.is_common = platform_list[0].is_common
        # Platform.__init__ builds reverse_category from self.categories, so
        # it must run after the merged categories are in place.
        super().__init__()

    def __str__(self):
        return "[" + " ".join(platform.name for platform in self.platform_list) + "]"

    async def get_target_name(self, _):
        return await self.platform_list[0].get_target_name(_)

    async def fetch_new_post(self, target, users):
        """Fetch from every member and merge the posts per user.

        Each member is polled through ``do_fetch_new_post``. A transient
        network/SSL/JSON error in one member is therefore logged and yields
        nothing for that member only. Previously such an error aborted the
        whole group, which discarded results already fetched from earlier
        members, including posts those members had already recorded as seen.
        """
        merged = defaultdict(list)
        for platform in self.platform_list:
            platform_res = await platform.do_fetch_new_post(target=target, users=users)
            for user, posts in platform_res:
                merged[user].extend(posts)
        # Tuples match the list[tuple[User, list[Post]]] contract of Platform.
        return [(user, posts) for user, posts in merged.items()]