fix circular import

This commit is contained in:
felinae98 2021-04-30 15:33:56 +08:00
parent 2a053422ee
commit b9b9464611
No known key found for this signature in database
GPG Key ID: 00C8B010587FF610
4 changed files with 28 additions and 19 deletions

View File

@ -1,13 +1,13 @@
from collections import defaultdict from collections import defaultdict
from os import path from os import path
import os import os
from typing import DefaultDict from typing import DefaultDict, Mapping
import nonebot import nonebot
from tinydb import Query, TinyDB from tinydb import Query, TinyDB
from .plugin_config import plugin_config from .plugin_config import plugin_config
from .types import User from .types import Target, User
from .utils import Singleton from .utils import Singleton
from .platform import platform_manager from .platform import platform_manager
@ -37,7 +37,7 @@ class Config(metaclass=Singleton):
self.db = TinyDB(get_config_path(), encoding='utf-8') self.db = TinyDB(get_config_path(), encoding='utf-8')
self.kv_config = self.db.table('kv') self.kv_config = self.db.table('kv')
self.user_target = self.db.table('user_target') self.user_target = self.db.table('user_target')
self.target_user_cache = {} self.target_user_cache: dict[str, defaultdict[Target, list[User]]] = {}
self.target_user_cat_cache = {} self.target_user_cat_cache = {}
self.target_user_tag_cache = {} self.target_user_tag_cache = {}
self.target_list = {} self.target_list = {}

View File

@ -6,10 +6,9 @@ from typing import Any, Collection, Optional
import httpx import httpx
from nonebot import logger from nonebot import logger
from ..config import Config
from ..plugin_config import plugin_config from ..plugin_config import plugin_config
from ..post import Post from ..post import Post
from ..types import Category, RawPost, Tag, Target, User from ..types import Category, RawPost, Tag, Target, User, UserSubInfo
class CategoryNotSupport(Exception): class CategoryNotSupport(Exception):
@ -49,7 +48,7 @@ class PlatformProto(metaclass=RegistryMeta):
schedule_interval: int schedule_interval: int
@abstractmethod @abstractmethod
async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]: async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]:
... ...
@staticmethod @staticmethod
@ -172,9 +171,8 @@ class Platform(PlatformProto):
return [] return []
return self._do_filter_common(raw_post_list, self.exists_posts[target]) return self._do_filter_common(raw_post_list, self.exists_posts[target])
async def fetch_new_post(self, target: Target, users: list[User]) -> list[tuple[User, list[Post]]]: async def fetch_new_post(self, target: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]:
try: try:
config = Config()
post_list = await self.get_sub_list(target) post_list = await self.get_sub_list(target)
new_posts = await self.filter_common(target, post_list) new_posts = await self.filter_common(target, post_list)
res: list[tuple[User, list[Post]]] = [] res: list[tuple[User, list[Post]]] = []
@ -183,9 +181,9 @@ class Platform(PlatformProto):
else: else:
for post in new_posts: for post in new_posts:
logger.info('fetch new post from {} {}: {}'.format(self.platform_name, target, self.get_id(post))) logger.info('fetch new post from {} {}: {}'.format(self.platform_name, target, self.get_id(post)))
for user in users: for user, category_getter, tag_getter in users:
required_tags = config.get_sub_tags(self.platform_name, target, user.user_type, user.user) if self.enable_tag else [] required_tags = tag_getter(target) if self.enable_tag else []
cats = config.get_sub_category(self.platform_name, target, user.user_type, user.user) cats = category_getter(target)
user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags) user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
user_post: list[Post] = [] user_post: list[Post] = []
for raw_post in user_raw_post: for raw_post in user_raw_post:
@ -228,9 +226,8 @@ class PlatformNoTarget(PlatformProto):
return [] return []
return self._do_filter_common(raw_post_list, self.exists_posts) return self._do_filter_common(raw_post_list, self.exists_posts)
async def fetch_new_post(self, _: Target, users: list[User]) -> list[tuple[User, list[Post]]]: async def fetch_new_post(self, _: Target, users: list[UserSubInfo]) -> list[tuple[User, list[Post]]]:
try: try:
config = Config()
post_list = await self.get_sub_list() post_list = await self.get_sub_list()
new_posts = await self.filter_common(post_list) new_posts = await self.filter_common(post_list)
res: list[tuple[User, list[Post]]] = [] res: list[tuple[User, list[Post]]] = []
@ -239,9 +236,9 @@ class PlatformNoTarget(PlatformProto):
else: else:
for post in new_posts: for post in new_posts:
logger.info('fetch new post from {}: {}'.format(self.platform_name, self.get_id(post))) logger.info('fetch new post from {}: {}'.format(self.platform_name, self.get_id(post)))
for user in users: for user, category_getter, tag_getter in users:
required_tags = config.get_sub_tags(self.platform_name, 'default', user.user_type, user.user) if self.enable_tag else [] required_tags = tag_getter(Target('default'))
cats = config.get_sub_category(self.platform_name, 'default', user.user_type, user.user) cats = category_getter(Target('default'))
user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags) user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags)
user_post: list[Post] = [] user_post: list[Post] = []
for raw_post in user_raw_post: for raw_post in user_raw_post:

View File

@ -6,6 +6,7 @@ from .config import Config
from .platform import platform_manager from .platform import platform_manager
from .send import do_send_msgs from .send import do_send_msgs
from .send import send_msgs from .send import send_msgs
from .types import UserSubInfo
scheduler = AsyncIOScheduler() scheduler = AsyncIOScheduler()
@ -21,10 +22,16 @@ async def fetch_and_send(target_type: str):
if not target: if not target:
return return
logger.debug('try to fecth new posts from {}, target: {}'.format(target_type, target)) logger.debug('try to fecth new posts from {}, target: {}'.format(target_type, target))
send_list = config.target_user_cache[target_type][target] send_user_list = config.target_user_cache[target_type][target]
send_userinfo_list = list(map(
lambda user: UserSubInfo(
user,
lambda target: config.get_sub_category(target_type, target, user.user_type, user.user),
lambda target: config.get_sub_tags(target_type, target, user.user_type, user.user)
), send_user_list))
bot_list = list(nonebot.get_bots().values()) bot_list = list(nonebot.get_bots().values())
bot = bot_list[0] if bot_list else None bot = bot_list[0] if bot_list else None
to_send = await platform_manager[target_type].fetch_new_post(target, send_list) to_send = await platform_manager[target_type].fetch_new_post(target, send_userinfo_list)
for user, send_list in to_send: for user, send_list in to_send:
for send_post in send_list: for send_post in send_list:
logger.info('send to {}: {}'.format(user, send_post)) logger.info('send to {}: {}'.format(user, send_post))

View File

@ -1,4 +1,4 @@
from typing import Any, NewType from typing import Any, Callable, NamedTuple, NewType
from dataclasses import dataclass from dataclasses import dataclass
RawPost = NewType('RawPost', Any) RawPost = NewType('RawPost', Any)
@ -10,3 +10,8 @@ Tag = NewType('Tag', str)
class User: class User:
user: str user: str
user_type: str user_type: str
class UserSubInfo(NamedTuple):
user: User
category_getter: Callable[[Target], list[Category]]
tag_getter: Callable[[Target], list[Tag]]