2021-02-16 23:20:41 +08:00

286 lines
11 KiB
Python

import time
from collections import defaultdict
from typing import Any, Optional
import httpx
from nonebot import logger
from ..config import Config
from ..post import Post
from ..types import Category, RawPost, Tag, Target, User
from ..utils import Singleton
class CategoryNotSupport(Exception):
"raise in get_category, when post category is not supported"
pass
class PlatformProto(metaclass=Singleton):
categories: dict[Category, str]
reverse_category: dict[str, Category]
has_target: bool
platform_name: str
enable_tag: bool
async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]:
...
@staticmethod
async def get_account_name(target: Target) -> Optional[str]:
...
class Platform(PlatformProto):
"platform with target(account), like weibo, bilibili"
categories: dict[Category, str]
has_target: bool = True
platform_name: str
enable_tag: bool
def __init__(self):
self.exists_posts = defaultdict(set)
self.inited = dict()
self.reverse_category = {}
self.cache: dict[Any, Post] = {}
for key, val in self.categories.items():
self.reverse_category[val] = key
@staticmethod
async def get_account_name(target: Target) -> Optional[str]:
"Given a tareget, return the username(name) of the target"
raise NotImplementedError()
async def get_sub_list(self, target: Target) -> list[RawPost]:
"Get post list of the given target"
raise NotImplementedError()
def get_id(self, post: RawPost) -> Any:
"Get post id of given RawPost"
raise NotImplementedError()
def get_date(self, post: RawPost) -> Optional[int]:
"Get post timestamp and return, return None if can't get the time"
raise NotImplementedError()
def get_category(self, post: RawPost) -> Optional[Category]:
"Return category of given Rawpost"
raise NotImplementedError()
def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]:
"Return Tag list of given RawPost"
raise NotImplementedError()
async def parse(self, raw_post: RawPost) -> Post:
"parse RawPost into post"
raise NotImplementedError()
def filter_platform_custom(self, post: RawPost) -> bool:
raise NotImplementedError()
async def _parse_with_cache(self, post: RawPost) -> Post:
post_id = self.get_id(post)
if post_id not in self.cache:
self.cache[post_id] = await self.parse(post)
return self.cache[post_id]
async def filter_common(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]:
if not self.inited.get(target, False):
# target not init
for raw_post in raw_post_list:
post_id = self.get_id(raw_post)
self.exists_posts[target].add(post_id)
logger.info('init {}-{} with {}'.format(self.platform_name, target, self.exists_posts[target]))
self.inited[target] = True
return []
res: list[RawPost] = []
for raw_post in raw_post_list:
post_id = self.get_id(raw_post)
if post_id in self.exists_posts[target]:
continue
if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60:
continue
try:
if not self.filter_platform_custom(raw_post):
continue
except NotImplementedError:
pass
try:
self.get_category(raw_post)
except CategoryNotSupport:
continue
except NotImplementedError:
pass
res.append(raw_post)
self.exists_posts[target].add(post_id)
return res
async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]:
res: list[RawPost] = []
for raw_post in raw_post_list:
if self.categories:
cat = self.get_category(raw_post)
if cats and cat not in cats:
continue
if self.enable_tag and tags:
flag = False
post_tags = self.get_tags(raw_post)
for tag in post_tags:
if tag in tags:
flag = True
break
if not flag:
continue
res.append(raw_post)
return res
async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]:
try:
config = Config()
post_list = await self.get_sub_list(target)
new_posts = await self.filter_common(target, post_list)
res: list[tuple[User, list[Post]]] = []
if not new_posts:
return []
else:
for post in new_posts:
logger.info('fetch new post from {} {}: {}'.format(self.platform_name, target, self.get_id(post)))
for user in users:
required_tags = config.get_sub_tags(self.platform_name, target, user.user_type, user.user) if self.enable_tag else []
cats = config.get_sub_category(self.platform_name, target, user.user_type, user.user)
user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
user_post: list[Post] = []
for raw_post in user_raw_post:
user_post.append(await self._parse_with_cache(raw_post))
res.append((user, user_post))
self.cache = {}
return res
except httpx.RequestError as err:
logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url))
return []
class PlatformNoTarget(PlatformProto):
categories: dict[Category, str]
has_target = False
platform_name: str
enable_tag: bool
def __init__(self):
self.exists_posts = set()
self.inited = False
self.reverse_category = {}
self.cache: dict[Any, Post] = {}
for key, val in self.categories.items():
self.reverse_category[val] = key
@staticmethod
async def get_account_name(target: Target) -> Optional[str]:
"return the username(name) of the target"
raise NotImplementedError()
async def get_sub_list(self) -> list[RawPost]:
"Get post list of the given target"
raise NotImplementedError()
def get_id(self, post: RawPost) -> Any:
"Get post id of given RawPost"
raise NotImplementedError()
def get_date(self, post: RawPost) -> Optional[int]:
"Get post timestamp and return, return None if can't get the time"
raise NotImplementedError()
def get_category(self, post: RawPost) -> Optional[Category]:
"Return category of given Rawpost"
raise NotImplementedError()
def get_tags(self, raw_post: RawPost) -> Optional[list[Tag]]:
"Return Tag list of given RawPost"
raise NotImplementedError()
async def parse(self, raw_post: RawPost) -> Post:
"parse RawPost into post"
raise NotImplementedError()
def filter_platform_custom(self, post: RawPost) -> bool:
raise NotImplementedError()
async def _parse_with_cache(self, post: RawPost) -> Post:
post_id = self.get_id(post)
if post_id not in self.cache:
self.cache[post_id] = await self.parse(post)
return self.cache[post_id]
async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]:
if not self.inited:
# target not init
for raw_post in raw_post_list:
post_id = self.get_id(raw_post)
self.exists_posts.add(post_id)
logger.info('init {} with {}'.format(self.platform_name, self.exists_posts))
self.inited = True
return []
res: list[RawPost] = []
for raw_post in raw_post_list:
post_id = self.get_id(raw_post)
if post_id in self.exists_posts:
continue
if (post_time := self.get_date(raw_post)) and time.time() - post_time > 2 * 60 * 60:
continue
try:
if not self.filter_platform_custom(raw_post):
continue
except NotImplementedError:
pass
try:
self.get_category(raw_post)
except CategoryNotSupport:
continue
res.append(raw_post)
self.exists_posts.add(post_id)
return res
async def filter_user_custom(self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag]) -> list[RawPost]:
res: list[RawPost] = []
for raw_post in raw_post_list:
if self.categories:
cat = self.get_category(raw_post)
if cats and cat not in cats:
continue
if self.enable_tag and tags:
flag = False
post_tags = self.get_tags(raw_post)
for tag in post_tags:
if tag in tags:
flag = True
break
if not flag:
continue
res.append(raw_post)
return res
async def fetch_new_post(self, users: list[User]) -> list[tuple[User, list[Post]]]:
try:
config = Config()
post_list = await self.get_sub_list()
new_posts = await self.filter_common(post_list)
res: list[tuple[User, list[Post]]] = []
if not new_posts:
return []
else:
for post in new_posts:
logger.info('fetch new post from {}: {}'.format(self.platform_name, self.get_id(post)))
for user in users:
required_tags = config.get_sub_tags(self.platform_name, 'default', user.user_type, user.user) if self.enable_tag else []
cats = config.get_sub_category(self.platform_name, 'default', user.user_type, user.user)
user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
user_post: list[Post] = []
for raw_post in user_raw_post:
user_post.append(await self._parse_with_cache(raw_post))
res.append((user, user_post))
self.cache = {}
return res
except httpx.RequestError as err:
logger.warning("network connection error: {}, url: {}".format(type(err), err.request.url))
return []