mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-06 20:06:12 +08:00
383 lines
12 KiB
Python
383 lines
12 KiB
Python
import json
|
|
import ssl
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import Any, Collection, Literal, Optional
|
|
|
|
import httpx
|
|
from nonebot.log import logger
|
|
|
|
from ..plugin_config import plugin_config
|
|
from ..post import Post
|
|
from ..types import Category, RawPost, Tag, Target, User, UserSubInfo
|
|
|
|
|
|
class CategoryNotSupport(Exception):
    """Signal from ``get_category`` that a post's category is not supported.

    ``MessageProcess.filter_common`` catches this exception and drops the
    offending post instead of dispatching it.
    """
|
|
|
|
|
|
class RegistryMeta(type):
    """Metaclass that collects concrete subclasses into a shared ``registry``.

    Recognised class keywords:

    * ``base=True`` -- the root of a hierarchy; it gets a fresh, empty
      ``registry`` list that its subclasses inherit.
    * ``abstract=True`` -- an intermediate class that is *not* registered.

    Every other class created with this metaclass appends itself to the
    ``registry`` it inherits.
    """

    def __new__(cls, name, bases, namespace, **kwargs):
        # The class keywords are consumed here and intentionally not forwarded:
        # type.__new__ would hand them to __init_subclass__, which rejects them.
        return super().__new__(cls, name, bases, namespace)

    def __init__(cls, name, bases, namespace, **kwargs):
        marks_root = kwargs.get("base")
        skip_registration = kwargs.get("abstract")
        if marks_root:
            # root class: start the registry shared by all of its subclasses
            cls.registry = []
        elif not skip_registration:
            # concrete subclass: record it in the inherited registry
            cls.registry.append(cls)
        super().__init__(name, bases, namespace, **kwargs)
|
|
|
|
|
|
class RegistryABCMeta(RegistryMeta, ABC):
    """``RegistryMeta`` variant used as the metaclass of ``Platform``.

    NOTE(review): ``ABC`` is an ordinary class whose own metaclass is
    ``ABCMeta``; it is not a metaclass itself. Class creation therefore goes
    through ``RegistryMeta.__new__`` -> ``type.__new__`` and never through
    ``ABCMeta.__new__``, so ``@abstractmethod`` is NOT enforced at
    instantiation for classes using this metaclass. Do not switch the base to
    ``ABCMeta`` without care: ``NoTargetGroup`` is instantiated directly while
    leaving ``parse``, ``get_tags`` and ``get_category`` unimplemented.
    """
|
|
|
|
|
|
class Platform(metaclass=RegistryABCMeta, base=True):
    """Base class of every subscription platform.

    Created with ``base=True``, so this class owns the ``registry`` list to
    which every non-abstract subclass is appended by ``RegistryMeta``.

    Subclasses set the class attributes below and implement fetching
    (``fetch_new_post``), parsing (``parse``) and classification
    (``get_tags`` / ``get_category``). This class provides per-user filtering
    by category and tag, dispatching to users, and a per-target data store.

    NOTE(review): ``@abstractmethod`` is not enforced here because the
    metaclass does not derive from ``ABCMeta`` (see ``RegistryABCMeta``), so
    the default bodies of the "abstract" methods below are actually reachable.
    """

    scheduler_class: str
    is_common: bool
    enabled: bool
    name: str
    has_target: bool
    # Category -> string label; inverted into ``reverse_category`` by __init__
    categories: dict[Category, str]
    enable_tag: bool
    # per-target state, accessed via get_stored_data / set_stored_data
    store: dict[Target, Any]
    platform_name: str
    # NOTE(review): "promot" looks like a typo of "prompt"; kept because it is
    # part of the public attribute name.
    parse_target_promot: Optional[str] = None

    @abstractmethod
    async def get_target_name(self, target: Target) -> Optional[str]:
        """Return the name associated with ``target``, or None."""
        ...

    @abstractmethod
    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Fetch new posts of ``target`` and return them grouped per user."""
        ...

    async def do_fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Call ``fetch_new_post``, turning common transient failures into ``[]``.

        ``httpx.RequestError``, ``ssl.SSLError`` and ``json.JSONDecodeError``
        are logged as warnings and produce an empty result; any other
        exception propagates to the caller.
        """
        try:
            return await self.fetch_new_post(target, users)
        except httpx.RequestError as err:
            logger.warning(
                "network connection error: {}, url: {}".format(
                    type(err), self._request_url(err)
                )
            )
            return []
        except ssl.SSLError as err:
            logger.warning(f"ssl error: {err}")
            return []
        except json.JSONDecodeError as err:
            logger.warning(f"json error, parsing: {err.doc}")
            return []

    @staticmethod
    def _request_url(err: httpx.RequestError) -> str:
        """Return the URL of the request behind ``err``, or a placeholder.

        ``httpx`` raises ``RuntimeError`` from ``HTTPError.request`` when the
        exception was created without a request attached. Letting that escape
        from the ``except`` handler in ``do_fetch_new_post`` would replace the
        network error with an unrelated crash.
        """
        try:
            return str(err.request.url)
        except (RuntimeError, AttributeError):
            return "<unknown>"

    @abstractmethod
    async def parse(self, raw_post: RawPost) -> Post:
        """Convert one raw post into a ``Post``."""
        ...

    async def do_parse(self, raw_post: RawPost) -> Post:
        """Parse ``raw_post``; the hook actually called during dispatch.

        The default simply delegates to ``parse``; subclasses may override it
        to add caching or retries.
        """
        return await self.parse(raw_post)

    def __init__(self):
        super().__init__()
        # inverse of ``categories``: label -> Category (last key wins on
        # duplicate labels)
        self.reverse_category = {val: key for key, val in self.categories.items()}
        self.store = {}

    class ParseTargetException(Exception):
        """Error type for ``parse_target`` overrides (not raised by this base class)."""

    async def parse_target(self, target_string: str) -> Target:
        """Turn user input into a ``Target``; by default the string is used as-is."""
        return Target(target_string)

    @abstractmethod
    def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]:
        """Return the tags of ``raw_post``; may return None when it has none."""

    def get_stored_data(self, target: Target) -> Any:
        """Return the data stored for ``target``, or None if nothing was stored."""
        return self.store.get(target)

    def set_stored_data(self, target: Target, data: Any):
        """Store ``data`` for ``target``, replacing any previous value."""
        self.store[target] = data

    async def filter_user_custom(
        self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]
    ) -> list[RawPost]:
        """Keep the posts that match one subscriber's category and tag filters.

        Category filter: applied only when the platform has ``categories``
        and ``cats`` is non-empty; the post's category must be in ``cats``.
        (``get_category`` is still called whenever the platform has categories.)

        Tag filter: applied only when ``enable_tag`` is set and ``tags`` is
        non-empty; the post must carry at least one tag from ``tags``.
        """
        res: list[RawPost] = []
        for raw_post in raw_post_list:
            if self.categories:
                cat = self.get_category(raw_post)
                if cats and cat not in cats:
                    continue
            if self.enable_tag and tags:
                post_tags = self.get_tags(raw_post) or []
                if not any(tag in tags for tag in post_tags):
                    continue
            res.append(raw_post)
        return res

    async def dispatch_user_post(
        self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Filter ``new_posts`` per subscriber and parse what each one receives.

        Returns one ``(user, posts)`` pair per entry of ``users``, in order;
        ``posts`` may be empty.
        """
        res: list[tuple[User, list[Post]]] = []
        for user, cats, required_tags in users:
            user_raw_post = await self.filter_user_custom(
                new_posts, cats, required_tags
            )
            # parse sequentially, preserving the original post order
            user_post: list[Post] = [
                await self.do_parse(raw_post) for raw_post in user_raw_post
            ]
            res.append((user, user_post))
        return res

    @abstractmethod
    def get_category(self, post: RawPost) -> Optional[Category]:
        """Return the category of ``post``.

        Implementations raise ``CategoryNotSupport`` for unsupported posts.
        This default raises ``NotImplementedError``, which
        ``MessageProcess.filter_common`` tolerates.
        """
        raise NotImplementedError()
|
|
|
|
|
|
class MessageProcess(Platform, abstract=True):
    """Platform that processes a list of posts: fetch, filter and parse.

    Adds a per-instance ``parse_cache`` (post id -> parsed ``Post``) and a
    common pre-filter that drops stale posts and posts whose category is
    unsupported.
    """

    def __init__(self):
        super().__init__()
        self.parse_cache: dict[Any, Post] = {}

    @abstractmethod
    def get_id(self, post: RawPost) -> Any:
        """Return the identifier of ``post``; used as the ``parse_cache`` key."""

    async def do_parse(self, raw_post: RawPost) -> Post:
        """Parse ``raw_post`` with up to three attempts, memoised by post id.

        A successful result is stored in ``self.parse_cache`` and reused for
        later calls with the same id. If all three attempts raise, the
        exception from the last attempt is re-raised.
        """
        post_id = self.get_id(raw_post)
        if post_id in self.parse_cache:
            return self.parse_cache[post_id]
        attempts_left = 3
        while True:
            try:
                self.parse_cache[post_id] = await self.parse(raw_post)
            except Exception as err:
                attempts_left -= 1
                if attempts_left == 0:
                    raise err
            else:
                return self.parse_cache[post_id]

    @abstractmethod
    async def get_sub_list(self, target: Target) -> list[RawPost]:
        """Return the current list of raw posts for ``target``."""

    @abstractmethod
    def get_date(self, post: RawPost) -> Optional[int]:
        """Return the timestamp of ``post``, or None when it cannot be determined."""

    def _category_supported(self, raw_post: RawPost) -> bool:
        """Return False only when ``get_category`` raises ``CategoryNotSupport``."""
        try:
            self.get_category(raw_post)
        except CategoryNotSupport:
            return False
        except NotImplementedError:
            # the platform does not classify posts: accept the post
            pass
        return True

    async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
        """Drop posts that no subscriber should receive, preserving order.

        A post is dropped when
        * it has a (truthy) timestamp older than two hours and
          ``plugin_config.bison_init_filter`` is enabled, or
        * ``get_category`` raises ``CategoryNotSupport`` for it.
        """
        kept = []
        for raw_post in raw_post_list:
            post_time = self.get_date(raw_post)
            is_stale = (
                post_time
                and time.time() - post_time > 2 * 60 * 60
                and plugin_config.bison_init_filter
            )
            if is_stale:
                continue
            if self._category_supported(raw_post):
                kept.append(raw_post)
        return kept
|
|
|
|
|
|
class NewMessage(MessageProcess, abstract=True):
    """Fetch a post list, keep only posts not seen before, and dispatch them."""

    @dataclass
    class MessageStorage:
        # True once the first snapshot of existing posts has been recorded
        inited: bool
        # ids of every post already seen for this target
        exists_posts: set[Any]

    async def filter_common_with_diff(
        self, target: Target, raw_post_list: list[RawPost]
    ) -> list[RawPost]:
        """Return the posts of ``raw_post_list`` that are new for ``target``.

        Posts first pass ``filter_common``. On the first call for a target,
        and only when ``plugin_config.bison_init_filter`` is enabled, every
        surviving post is merely recorded as seen and nothing is returned.
        Afterwards only posts whose id has not been seen are returned, and
        they are recorded as seen.
        """
        filtered_post = await self.filter_common(raw_post_list)
        store = self.get_stored_data(target) or self.MessageStorage(False, set())
        fresh: list[RawPost] = []
        if store.inited or not plugin_config.bison_init_filter:
            for raw_post in filtered_post:
                post_id = self.get_id(raw_post)
                if post_id not in store.exists_posts:
                    fresh.append(raw_post)
                    store.exists_posts.add(post_id)
        else:
            # first fetch for this target: take a snapshot, send nothing
            store.exists_posts.update(
                self.get_id(raw_post) for raw_post in filtered_post
            )
            logger.info(
                "init {}-{} with {}".format(
                    self.platform_name, target, store.exists_posts
                )
            )
            store.inited = True
        self.set_stored_data(target, store)
        return fresh

    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Fetch, deduplicate and dispatch the new posts of ``target``.

        Returns ``[]`` when there is nothing new; otherwise the result of
        ``dispatch_user_post``. The parse cache is reset after dispatching.
        """
        post_list = await self.get_sub_list(target)
        new_posts = await self.filter_common_with_diff(target, post_list)
        if not new_posts:
            return []
        shown_target = target if self.has_target else "-"
        for post in new_posts:
            logger.info(
                "fetch new post from {} {}: {}".format(
                    self.platform_name, shown_target, self.get_id(post)
                )
            )
        dispatched = await self.dispatch_user_post(target, new_posts, users)
        self.parse_cache = {}
        return dispatched
|
|
|
|
|
|
class StatusChange(Platform, abstract=True):
    """Watch a status and fire posts when it changes.

    The last observed status of each target is kept via
    ``get_stored_data`` / ``set_stored_data``. On every fetch the new status is
    compared with the stored one using ``compare_status``, and the resulting
    raw posts are dispatched to subscribers.
    """

    class FetchError(RuntimeError):
        """Raised by ``get_status`` when the status could not be fetched."""

    @abstractmethod
    async def get_status(self, target: Target) -> Any:
        """Return the current status of ``target``; raise ``FetchError`` on failure."""
        ...

    @abstractmethod
    def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]:
        """Return the raw posts describing the change from ``old_status`` to ``new_status``."""
        ...

    @abstractmethod
    async def parse(self, raw_post: RawPost) -> Post:
        """Convert one raw post produced by ``compare_status`` into a ``Post``."""
        ...

    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Fetch the status of ``target`` and dispatch posts for any change.

        On ``FetchError`` a warning is logged, the stored status is left
        unchanged and ``[]`` is returned. The first status ever seen for a
        target is only recorded; nothing is dispatched for it.
        """
        try:
            new_status = await self.get_status(target)
        except self.FetchError as err:
            logger.warning(f"fetching {self.name}-{target} error: {err}")
            return []
        res = []
        old_status = self.get_stored_data(target)
        # ``get_stored_data`` returns None only when nothing was stored yet.
        # A truthiness test would also skip legitimately falsy statuses
        # (False, 0, {}), silently losing the change that follows them.
        if old_status is not None:
            diff = self.compare_status(target, old_status, new_status)
            if diff:
                logger.info(
                    "status changes {} {}: {} -> {}".format(
                        self.platform_name,
                        target if self.has_target else "-",
                        old_status,
                        new_status,
                    )
                )
                res = await self.dispatch_user_post(target, diff, users)
        self.set_stored_data(target, new_status)
        return res
|
|
|
|
|
|
class SimplePost(MessageProcess, abstract=True):
    """Fetch a list of messages and dispatch all of them, without deduplication."""

    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Dispatch the whole current post list of ``target`` to ``users``.

        Neither ``filter_common`` nor any seen-post deduplication is applied
        (per-user filtering still happens inside ``dispatch_user_post``).
        Returns ``[]`` for an empty list; the parse cache is reset after
        dispatching.
        """
        sub_list = await self.get_sub_list(target)
        if not sub_list:
            return []
        shown_target = target if self.has_target else "-"
        for raw_post in sub_list:
            logger.info(
                "fetch new post from {} {}: {}".format(
                    self.platform_name, shown_target, self.get_id(raw_post)
                )
            )
        dispatched = await self.dispatch_user_post(target, sub_list, users)
        self.parse_cache = {}
        return dispatched
|
|
|
|
|
|
class NoTargetGroup(Platform, abstract=True):
    """Combine several target-less platforms with the same ``name`` into one.

    Every member must have ``has_target`` unset, share the same ``name`` and
    ``scheduler_class``, and use category keys disjoint from the other
    members'. Categories are merged. ``platform_name`` and ``is_common`` are
    taken from the first member. A ``RuntimeError`` is raised on any
    violation.
    """

    enable_tag = False
    DUMMY_STR = "_DUMMY"
    enabled = True
    has_target = False

    def __init__(self, platform_list: list[Platform]):
        self.platform_list = platform_list
        self.platform_name = platform_list[0].platform_name
        self.categories = {}
        self.scheduler_class = platform_list[0].scheduler_class
        merged_name = self.DUMMY_STR
        seen_category_keys = set()
        for member in platform_list:
            if member.has_target:
                raise RuntimeError(
                    "Platform {} should have no target".format(member.name)
                )
            # the first member fixes the group name; all others must match it
            if merged_name == self.DUMMY_STR:
                merged_name = member.name
            elif merged_name != member.name:
                raise RuntimeError(
                    "Platform name for {} not fit".format(self.platform_name)
                )
            member_keys = set(member.categories)
            if not member_keys.isdisjoint(seen_category_keys):
                raise RuntimeError(
                    "Platform categories for {} duplicate".format(self.platform_name)
                )
            seen_category_keys.update(member_keys)
            self.categories.update(member.categories)
            if member.scheduler_class != self.scheduler_class:
                raise RuntimeError(
                    "Platform scheduler for {} not fit".format(self.platform_name)
                )
        self.name = merged_name
        self.is_common = platform_list[0].is_common
        # Platform.__init__ builds reverse_category from the merged categories
        super().__init__()

    def __str__(self):
        member_names = " ".join(member.name for member in self.platform_list)
        return f"[{member_names}]"

    async def get_target_name(self, _):
        """Delegate to the first member platform."""
        first_member = self.platform_list[0]
        return await first_member.get_target_name(_)

    async def fetch_new_post(self, target, users):
        """Fetch from every member and merge the posts per user.

        Returns a list of ``[user, posts]`` lists; users keep the order in
        which they were first returned by a member.
        """
        merged = defaultdict(list)
        for member in self.platform_list:
            member_result = await member.fetch_new_post(target=target, users=users)
            for user, posts in member_result:
                merged[user].extend(posts)
        return [[user, posts] for user, posts in merged.items()]
|