mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-07-09 15:33:01 +08:00
286 lines
11 KiB
Python
286 lines
11 KiB
Python
import time
|
|
from collections import defaultdict
|
|
from typing import Any, Optional
|
|
|
|
import httpx
|
|
from nonebot import logger
|
|
|
|
from ..config import Config
|
|
from ..post import Post
|
|
from ..types import Category, RawPost, Tag, Target, User
|
|
from ..utils import Singleton
|
|
|
|
|
|
class CategoryNotSupport(Exception):
|
|
"raise in get_category, when post category is not supported"
|
|
pass
|
|
|
|
class PlatformProto(metaclass=Singleton):
|
|
|
|
categories: dict[Category, str]
|
|
reverse_category: dict[str, Category]
|
|
has_target: bool
|
|
platform_name: str
|
|
enable_tag: bool
|
|
async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]:
|
|
...
|
|
@staticmethod
|
|
async def get_account_name(target: Target) -> Optional[str]:
|
|
...
|
|
|
|
class Platform(PlatformProto):
|
|
"platform with target(account), like weibo, bilibili"
|
|
|
|
categories: dict[Category, str]
|
|
has_target: bool = True
|
|
platform_name: str
|
|
enable_tag: bool
|
|
|
|
def __init__(self):
|
|
self.exists_posts = defaultdict(set)
|
|
self.inited = dict()
|
|
self.reverse_category = {}
|
|
self.cache: dict[Any, Post] = {}
|
|
for key, val in self.categories.items():
|
|
self.reverse_category[val] = key
|
|
|
|
@staticmethod
|
|
async def get_account_name(target: Target) -> Optional[str]:
|
|
"Given a tareget, return the username(name) of the target"
|
|
raise NotImplementedError()
|
|
|
|
async def get_sub_list(self, target: Target) -> list[RawPost]:
|
|
"Get post list of the given target"
|
|
raise NotImplementedError()
|
|
|
|
def get_id(self, post: RawPost) -> Any:
|
|
"Get post id of given RawPost"
|
|
raise NotImplementedError()
|
|
|
|
def get_date(self, post: RawPost) -> Optional[int]:
|
|
"Get post timestamp and return, return None if can't get the time"
|
|
raise NotImplementedError()
|
|
|
|
def get_category(self, post: RawPost) -> Optional[Category]:
|
|
"Return category of given Rawpost"
|
|
raise NotImplementedError()
|
|
|
|
def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]:
|
|
"Return Tag list of given RawPost"
|
|
raise NotImplementedError()
|
|
|
|
async def parse(self, raw_post: RawPost) -> Post:
|
|
"parse RawPost into post"
|
|
raise NotImplementedError()
|
|
|
|
def filter_platform_custom(self, post: RawPost) -> bool:
|
|
raise NotImplementedError()
|
|
|
|
async def _parse_with_cache(self, post: RawPost) -> Post:
|
|
post_id = self.get_id(post)
|
|
if post_id not in self.cache:
|
|
self.cache[post_id] = await self.parse(post)
|
|
return self.cache[post_id]
|
|
|
|
async def filter_common(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]:
|
|
if not self.inited.get(target, False):
|
|
# target not init
|
|
for raw_post in raw_post_list:
|
|
post_id = self.get_id(raw_post)
|
|
self.exists_posts[target].add(post_id)
|
|
logger.info('init {}-{} with {}'.format(self.platform_name, target, self.exists_posts[target]))
|
|
self.inited[target] = True
|
|
return []
|
|
res: list[RawPost] = []
|
|
for raw_post in raw_post_list:
|
|
post_id = self.get_id(raw_post)
|
|
if post_id in self.exists_posts[target]:
|
|
continue
|
|
if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60:
|
|
continue
|
|
try:
|
|
if not self.filter_platform_custom(raw_post):
|
|
continue
|
|
except NotImplementedError:
|
|
pass
|
|
try:
|
|
self.get_category(raw_post)
|
|
except CategoryNotSupport:
|
|
continue
|
|
except NotImplementedError:
|
|
pass
|
|
res.append(raw_post)
|
|
self.exists_posts[target].add(post_id)
|
|
return res
|
|
|
|
async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]:
|
|
res: list[RawPost] = []
|
|
for raw_post in raw_post_list:
|
|
if self.categories:
|
|
cat = self.get_category(raw_post)
|
|
if cats and cat not in cats:
|
|
continue
|
|
if self.enable_tag and tags:
|
|
flag = False
|
|
post_tags = self.get_tags(raw_post)
|
|
for tag in post_tags:
|
|
if tag in tags:
|
|
flag = True
|
|
break
|
|
if not flag:
|
|
continue
|
|
res.append(raw_post)
|
|
return res
|
|
|
|
async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]:
|
|
try:
|
|
config = Config()
|
|
post_list = await self.get_sub_list(target)
|
|
new_posts = await self.filter_common(target, post_list)
|
|
res: list[tuple[User, list[Post]]] = []
|
|
if not new_posts:
|
|
return []
|
|
else:
|
|
for post in new_posts:
|
|
logger.info('fetch new post from {} {}: {}'.format(self.platform_name, target, self.get_id(post)))
|
|
for user in users:
|
|
required_tags = config.get_sub_tags(self.platform_name, target, user.user_type, user.user) if self.enable_tag else []
|
|
cats = config.get_sub_category(self.platform_name, target, user.user_type, user.user)
|
|
user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
|
|
user_post: list[Post] = []
|
|
for raw_post in user_raw_post:
|
|
user_post.append(await self._parse_with_cache(raw_post))
|
|
res.append((user, user_post))
|
|
self.cache = {}
|
|
return res
|
|
except httpx.RequestError as err:
|
|
logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url))
|
|
return []
|
|
|
|
|
|
class PlatformNoTarget(PlatformProto):
|
|
|
|
categories: dict[Category, str]
|
|
has_target = False
|
|
platform_name: str
|
|
enable_tag: bool
|
|
|
|
def __init__(self):
|
|
self.exists_posts = set()
|
|
self.inited = False
|
|
self.reverse_category = {}
|
|
self.cache: dict[Any, Post] = {}
|
|
for key, val in self.categories.items():
|
|
self.reverse_category[val] = key
|
|
|
|
@staticmethod
|
|
async def get_account_name(target: Target) -> Optional[str]:
|
|
"return the username(name) of the target"
|
|
raise NotImplementedError()
|
|
|
|
async def get_sub_list(self) -> list[RawPost]:
|
|
"Get post list of the given target"
|
|
raise NotImplementedError()
|
|
|
|
def get_id(self, post: RawPost) -> Any:
|
|
"Get post id of given RawPost"
|
|
raise NotImplementedError()
|
|
|
|
def get_date(self, post: RawPost) -> Optional[int]:
|
|
"Get post timestamp and return, return None if can't get the time"
|
|
raise NotImplementedError()
|
|
|
|
def get_category(self, post: RawPost) -> Optional[Category]:
|
|
"Return category of given Rawpost"
|
|
raise NotImplementedError()
|
|
|
|
def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]:
|
|
"Return Tag list of given RawPost"
|
|
raise NotImplementedError()
|
|
|
|
async def parse(self, raw_post: RawPost) -> Post:
|
|
"parse RawPost into post"
|
|
raise NotImplementedError()
|
|
|
|
def filter_platform_custom(self, post: RawPost) -> bool:
|
|
raise NotImplementedError()
|
|
|
|
async def _parse_with_cache(self, post: RawPost) -> Post:
|
|
post_id = self.get_id(post)
|
|
if post_id not in self.cache:
|
|
self.cache[post_id] = await self.parse(post)
|
|
return self.cache[post_id]
|
|
|
|
async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
|
|
if not self.inited:
|
|
# target not init
|
|
for raw_post in raw_post_list:
|
|
post_id = self.get_id(raw_post)
|
|
self.exists_posts.add(post_id)
|
|
logger.info('init {} with {}'.format(self.platform_name, self.exists_posts))
|
|
self.inited = True
|
|
return []
|
|
res: list[RawPost] = []
|
|
for raw_post in raw_post_list:
|
|
post_id = self.get_id(raw_post)
|
|
if post_id in self.exists_posts:
|
|
continue
|
|
if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60:
|
|
continue
|
|
try:
|
|
if not self.filter_platform_custom(raw_post):
|
|
continue
|
|
except NotImplementedError:
|
|
pass
|
|
try:
|
|
self.get_category(raw_post)
|
|
except CategoryNotSupport:
|
|
continue
|
|
res.append(raw_post)
|
|
self.exists_posts.add(post_id)
|
|
return res
|
|
|
|
async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]:
|
|
res: list[RawPost] = []
|
|
for raw_post in raw_post_list:
|
|
if self.categories:
|
|
cat = self.get_category(raw_post)
|
|
if cats and cat not in cats:
|
|
continue
|
|
if self.enable_tag and tags:
|
|
flag = False
|
|
post_tags = self.get_tags(raw_post)
|
|
for tag in post_tags:
|
|
if tag in tags:
|
|
flag = True
|
|
break
|
|
if not flag:
|
|
continue
|
|
res.append(raw_post)
|
|
return res
|
|
|
|
async def fetch_new_post(self, users: list[User]) -> list[tuple[User, list[Post]]]:
|
|
try:
|
|
config = Config()
|
|
post_list = await self.get_sub_list()
|
|
new_posts = await self.filter_common(post_list)
|
|
res: list[tuple[User, list[Post]]] = []
|
|
if not new_posts:
|
|
return []
|
|
else:
|
|
for post in new_posts:
|
|
logger.info('fetch new post from {}: {}'.format(self.platform_name, self.get_id(post)))
|
|
for user in users:
|
|
required_tags = config.get_sub_tags(self.platform_name, 'default', user.user_type, user.user) if self.enable_tag else []
|
|
cats = config.get_sub_category(self.platform_name, 'default', user.user_type, user.user)
|
|
user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
|
|
user_post: list[Post] = []
|
|
for raw_post in user_raw_post:
|
|
user_post.append(await self._parse_with_cache(raw_post))
|
|
res.append((user, user_post))
|
|
self.cache = {}
|
|
return res
|
|
except httpx.RequestError as err:
|
|
logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url))
|
|
return []
|