2022-02-12 10:20:02 +08:00

378 lines
13 KiB
Python

import time
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Collection, Literal, Optional
import httpx
from nonebot import logger
from ..plugin_config import plugin_config
from ..post import Post
from ..types import Category, RawPost, Tag, Target, User, UserSubInfo
class CategoryNotSupport(Exception):
    """Raised from ``get_category`` when a post's category is not supported."""
class RegistryMeta(type):
    """Metaclass that gathers concrete subclasses into a shared ``registry`` list.

    Two class-definition keywords are recognised:

    * ``base=True`` -- the class being created is the root of a hierarchy and
      receives a fresh, empty ``registry`` list.
    * ``abstract=True`` -- the class is left out of the registry.

    A class created with neither keyword is appended to whatever ``registry``
    ordinary attribute lookup finds on it (normally the list created on its
    root class).
    """

    def __new__(mcs, name, bases, namespace, **options):
        # The extra keywords are deliberately not passed on to type.__new__;
        # they are only interpreted in __init__ below.
        return super().__new__(mcs, name, bases, namespace)

    def __init__(cls, name, bases, namespace, **options):
        is_root = options.get("base")
        if is_root:
            # Root class: start a new registry.
            cls.registry = []
        elif not options.get("abstract"):
            # Concrete subclass: record it in the inherited registry.
            cls.registry.append(cls)
        super().__init__(name, bases, namespace, **options)
class RegistryABCMeta(RegistryMeta, ABC):
    """Registry metaclass that also lists ``ABC`` among its bases.

    It adds no behaviour of its own; registration is inherited from
    ``RegistryMeta``.
    """

    # NOTE(review): the second base is ``ABC`` (a class), not ``ABCMeta``.
    # Whether ``@abstractmethod`` is actually enforced for classes built with
    # this metaclass should be confirmed before relying on it.
class Platform(metaclass=RegistryABCMeta, base=True):
    """Root class of all platforms.

    It is created with ``base=True``, so it owns the ``registry`` list that the
    metaclass fills with every non-abstract subclass.  Besides declaring the
    hooks a platform has to provide, it implements the per-user filtering and
    dispatching of raw posts.
    """

    # Scheduler configuration.
    schedule_type: Literal["date", "interval", "cron"]
    schedule_kw: dict
    is_common: bool
    enabled: bool
    name: str
    has_target: bool
    # Category -> label mapping; ``__init__`` builds the inverse of it.
    categories: dict[Category, str]
    # When false, tag based filtering is skipped entirely.
    enable_tag: bool
    # Per-target data, accessed through get_stored_data / set_stored_data.
    store: dict[Target, Any]
    platform_name: str

    def __init__(self):
        super().__init__()
        # Inverse of ``categories``: label -> category.
        self.reverse_category = {
            label: category for category, label in self.categories.items()
        }
        self.store = {}

    @abstractmethod
    async def get_target_name(self, target: Target) -> Optional[str]:
        """Return the name of *target* (``None`` is allowed by the signature)."""

    @abstractmethod
    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Fetch posts for *target* and return them grouped per user."""

    @abstractmethod
    async def parse(self, raw_post: RawPost) -> Post:
        """Convert one raw post into a ``Post``."""

    async def do_parse(self, raw_post: RawPost) -> Post:
        """Entry point actually used for parsing; delegates to ``parse``."""
        parsed = await self.parse(raw_post)
        return parsed

    @abstractmethod
    def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]:
        """Return the tags of the given raw post."""

    @abstractmethod
    def get_category(self, post: RawPost) -> Optional[Category]:
        """Return the category of the given raw post.

        Implementations signal an unsupported category by raising
        ``CategoryNotSupport``; this default raises ``NotImplementedError``.
        """
        raise NotImplementedError()

    def get_stored_data(self, target: Target) -> Any:
        """Return the data stored for *target*, or ``None`` if there is none."""
        return self.store.get(target)

    def set_stored_data(self, target: Target, data: Any):
        """Store *data* for *target*, replacing any previous value."""
        self.store[target] = data

    async def filter_user_custom(
        self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]
    ) -> list[RawPost]:
        """Keep the raw posts matching one user's category and tag choices.

        An empty *cats* or *tags* means "no restriction" for that criterion.
        Tags are only considered when ``enable_tag`` is set; a post passes the
        tag check as soon as one of its tags is in *tags*.
        """
        accepted: list[RawPost] = []
        for raw_post in raw_post_list:
            if self.categories:
                # The category is looked up even when the user has no
                # category restriction.
                category = self.get_category(raw_post)
                if cats and category not in cats:
                    continue
            if self.enable_tag and tags:
                post_tags = self.get_tags(raw_post) or []
                if not any(tag in tags for tag in post_tags):
                    continue
            accepted.append(raw_post)
        return accepted

    async def dispatch_user_post(
        self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Filter *new_posts* for every user and parse what is left.

        Each entry of *users* unpacks into ``(user, category_getter,
        tag_getter)``; both getters are called with *target*.  Every user gets
        an entry in the result, even when no post is left for them.
        """
        dispatched: list[tuple[User, list[Post]]] = []
        for user, category_getter, tag_getter in users:
            wanted_tags = tag_getter(target) if self.enable_tag else []
            wanted_cats = category_getter(target)
            selected = await self.filter_user_custom(
                new_posts, wanted_cats, wanted_tags
            )
            parsed: list[Post] = [await self.do_parse(raw) for raw in selected]
            dispatched.append((user, parsed))
        return dispatched
class MessageProcess(Platform, abstract=True):
    """Common fetch / parse / filter steps for message based platforms."""

    def __init__(self):
        super().__init__()
        # post id -> parsed Post, so a post shared by several users is parsed once.
        self.parse_cache: dict[Any, Post] = {}

    @abstractmethod
    def get_id(self, post: RawPost) -> Any:
        """Return the id of the given raw post."""

    @abstractmethod
    async def get_sub_list(self, target: Target) -> list[RawPost]:
        """Return the raw posts of the given target."""

    @abstractmethod
    def get_date(self, post: RawPost) -> Optional[int]:
        """Return the post's timestamp, or ``None`` when it cannot be determined."""

    async def do_parse(self, raw_post: RawPost) -> Post:
        """Parse *raw_post* through ``parse_cache``.

        A post that is not cached yet is parsed with up to three attempts; the
        exception of the third failed attempt is re-raised.
        """
        cache_key = self.get_id(raw_post)
        if cache_key in self.parse_cache:
            return self.parse_cache[cache_key]
        for attempts_left in (2, 1, 0):
            try:
                self.parse_cache[cache_key] = await self.parse(raw_post)
            except Exception as err:
                if attempts_left == 0:
                    raise err
            else:
                break
        return self.parse_cache[cache_key]

    async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
        """Apply the user independent filters to *raw_post_list*.

        A post is dropped when

        * it has a timestamp older than two hours and
          ``plugin_config.bison_init_filter`` is set, or
        * ``get_category`` raises ``CategoryNotSupport`` for it.

        A ``NotImplementedError`` from ``get_category`` is ignored and the post
        is kept.
        """
        max_age = 2 * 60 * 60  # seconds
        kept = []
        for raw_post in raw_post_list:
            post_time = self.get_date(raw_post)
            # Posts without a (truthy) timestamp are never treated as outdated.
            outdated = bool(post_time) and time.time() - post_time > max_age
            if outdated and plugin_config.bison_init_filter:
                continue
            try:
                self.get_category(raw_post)
            except CategoryNotSupport:
                continue
            except NotImplementedError:
                # No category lookup available: nothing to filter on.
                pass
            kept.append(raw_post)
        return kept
class NewMessage(MessageProcess, abstract=True):
    """Fetch a list of messages, keep the unseen ones and dispatch them to users."""

    @dataclass
    class MessageStorage:
        """Per-target bookkeeping kept in the platform store."""

        # True once the target went through the initial "remember only" pass.
        inited: bool
        # Ids of every post already seen for the target.
        exists_posts: set[Any]

    async def filter_common_with_diff(
        self, target: Target, raw_post_list: list[RawPost]
    ) -> list[RawPost]:
        """Return the posts of *raw_post_list* not seen before for *target*.

        The list first goes through ``filter_common``.  On the first call for a
        target, with ``plugin_config.bison_init_filter`` set, the surviving
        posts are only recorded as seen and nothing is returned.  Otherwise
        every post whose id is unknown is returned and recorded.
        """
        filtered_post = await self.filter_common(raw_post_list)
        store = self.get_stored_data(target) or self.MessageStorage(False, set())
        res = []
        if not store.inited and plugin_config.bison_init_filter:
            # Target not initialised yet: remember what exists, report nothing.
            store.exists_posts.update(self.get_id(raw) for raw in filtered_post)
            logger.info(
                "init {}-{} with {}".format(
                    self.platform_name, target, store.exists_posts
                )
            )
            store.inited = True
        else:
            for raw_post in filtered_post:
                post_id = self.get_id(raw_post)
                if post_id in store.exists_posts:
                    continue
                res.append(raw_post)
                store.exists_posts.add(post_id)
        self.set_stored_data(target, store)
        return res

    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Fetch the target's posts and dispatch the new ones to *users*.

        Returns an empty list when there is nothing new or when an
        ``httpx.RequestError`` occurs (the error is logged as a warning).
        """
        try:
            post_list = await self.get_sub_list(target)
            new_posts = await self.filter_common_with_diff(target, post_list)
            if not new_posts:
                return []
            for post in new_posts:
                logger.info(
                    "fetch new post from {} {}: {}".format(
                        self.platform_name,
                        target if self.has_target else "-",
                        self.get_id(post),
                    )
                )
            try:
                return await self.dispatch_user_post(target, new_posts, users)
            finally:
                # The cache only serves one dispatch round.  Reset it even when
                # dispatching fails, so half-filled entries do not outlive the
                # round (previously it was reset on success only).
                self.parse_cache = {}
        except httpx.RequestError as err:
            logger.warning(
                "network connection error: {}, url: {}".format(
                    type(err), err.request.url
                )
            )
            return []
class StatusChange(Platform, abstract=True):
    """Watch a status and emit posts when it changes."""

    @abstractmethod
    async def get_status(self, target: Target) -> Any:
        """Return the current status of *target*."""

    @abstractmethod
    def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]:
        """Return the raw posts describing the change from old to new status."""

    @abstractmethod
    async def parse(self, raw_post: RawPost) -> Post:
        """Convert one raw post into a ``Post``."""

    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Compare the current status with the stored one and dispatch the diff.

        The new status replaces the stored one after a successful round.  A
        falsy stored status (including "nothing stored yet") skips the
        comparison.  An ``httpx.RequestError`` is logged as a warning and
        yields an empty list; the stored status is left untouched in that case.
        """
        try:
            new_status = await self.get_status(target)
            old_status = self.get_stored_data(target)
            dispatched = []
            if old_status:
                changes = self.compare_status(target, old_status, new_status)
                if changes:
                    shown_target = target if self.has_target else "-"
                    logger.info(
                        "status changes {} {}: {} -> {}".format(
                            self.platform_name,
                            shown_target,
                            old_status,
                            new_status,
                        )
                    )
                    dispatched = await self.dispatch_user_post(target, changes, users)
            self.set_stored_data(target, new_status)
            return dispatched
        except httpx.RequestError as err:
            message = "network connection error: {}, url: {}".format(
                type(err), err.request.url
            )
            logger.warning(message)
            return []
class SimplePost(MessageProcess, abstract=True):
    """Fetch a list of messages and dispatch all of them to users."""

    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Fetch the target's posts and dispatch every one of them to *users*.

        Returns an empty list when the target has no posts or when an
        ``httpx.RequestError`` occurs (the error is logged as a warning).
        """
        try:
            new_posts = await self.get_sub_list(target)
            if not new_posts:
                return []
            for post in new_posts:
                logger.info(
                    "fetch new post from {} {}: {}".format(
                        self.platform_name,
                        target if self.has_target else "-",
                        self.get_id(post),
                    )
                )
            try:
                return await self.dispatch_user_post(target, new_posts, users)
            finally:
                # Posts are not de-duplicated here, so a cache entry surviving a
                # failed round could be served again on the next fetch.  Reset
                # the cache on every exit (previously on success only).
                self.parse_cache = {}
        except httpx.RequestError as err:
            logger.warning(
                "network connection error: {}, url: {}".format(
                    type(err), err.request.url
                )
            )
            return []
class NoTargetGroup(Platform, abstract=True):
    """Present several target-less platforms sharing one name as one platform.

    The members must have no target, identical ``name``, disjoint category
    keys and identical scheduler settings; the group's ``categories`` is the
    union of the members' categories.
    """

    enable_tag = False
    DUMMY_STR = "_DUMMY"
    enabled = True
    has_target = False

    def __init__(self, platform_list: list[Platform]):
        """Validate *platform_list* and merge the members' settings.

        Raises:
            IndexError: if *platform_list* is empty.
            RuntimeError: if a member has a target, or the members disagree on
                name or scheduler settings, or their category keys overlap.
        """
        self.platform_list = platform_list
        first = platform_list[0]
        # The error messages below format ``self.platform_name``, which used to
        # be unset on the group, so the intended RuntimeError surfaced as an
        # AttributeError.  Take it from the first member, falling back to that
        # member's ``name`` if it defines no ``platform_name``.
        self.platform_name = getattr(first, "platform_name", first.name)
        name = self.DUMMY_STR
        self.categories = {}
        self.schedule_type = first.schedule_type
        self.schedule_kw = first.schedule_kw
        for platform in platform_list:
            if platform.has_target:
                raise RuntimeError(
                    "Platform {} should have no target".format(platform.name)
                )
            if name == self.DUMMY_STR:
                name = platform.name
            elif name != platform.name:
                raise RuntimeError(
                    "Platform name for {} not fit".format(self.platform_name)
                )
            # Category keys must be unique across the whole group.
            if self.categories.keys() & platform.categories.keys():
                raise RuntimeError(
                    "Platform categories for {} duplicate".format(self.platform_name)
                )
            self.categories.update(platform.categories)
            if (
                platform.schedule_kw != self.schedule_kw
                or platform.schedule_type != self.schedule_type
            ):
                raise RuntimeError(
                    "Platform scheduler for {} not fit".format(self.platform_name)
                )
        self.name = name
        self.is_common = first.is_common
        # Platform.__init__ derives reverse_category from the merged categories.
        super().__init__()

    def __str__(self):
        """Return the members' names, space separated, in square brackets."""
        return "[" + " ".join(platform.name for platform in self.platform_list) + "]"

    async def get_target_name(self, _):
        """Delegate to the first member platform."""
        return await self.platform_list[0].get_target_name(_)

    async def fetch_new_post(
        self, target: Target, users: list[UserSubInfo]
    ) -> list[tuple[User, list[Post]]]:
        """Fetch from every member in turn and merge the posts per user.

        Returns ``(user, posts)`` tuples, matching the base-class signature
        (the pairs used to be two-element lists).
        """
        merged = defaultdict(list)
        for platform in self.platform_list:
            platform_res = await platform.fetch_new_post(target=target, users=users)
            for user, posts in platform_res:
                merged[user].extend(posts)
        return list(merged.items())