mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-06 20:06:12 +08:00
144 lines
6.0 KiB
Python
144 lines
6.0 KiB
Python
from collections import defaultdict
|
||
from os import path
|
||
import os
|
||
from typing import DefaultDict, Mapping
|
||
|
||
import nonebot
|
||
from tinydb import Query, TinyDB
|
||
|
||
from .plugin_config import plugin_config
|
||
from .types import Target, User
|
||
from .utils import Singleton
|
||
from .platform import platform_manager
|
||
|
||
# Target types this module accepts: the keys of platform_manager.
# Config.update_send_cache skips any stored subscription whose 'target_type'
# is not among them.
supported_target_type = platform_manager.keys()
|
||
|
||
def get_config_path() -> str:
    """Return the path of the ``bison.json`` database file.

    The containing directory is ``plugin_config.bison_config_path`` when that
    setting is truthy, otherwise ``<current working directory>/data``. The
    directory is created when it does not exist yet. If a legacy
    ``hk_reporter.json`` is present and ``bison.json`` is not, the legacy file
    is renamed to ``bison.json`` before the path is returned.
    """
    configured_dir = plugin_config.bison_config_path
    data_dir = configured_dir if configured_dir else path.join(os.getcwd(), 'data')

    if not path.isdir(data_dir):
        os.makedirs(data_dir)

    legacy_file = path.join(data_dir, 'hk_reporter.json')
    current_file = path.join(data_dir, 'bison.json')

    # Only rename when it cannot clobber an already existing new-style file.
    legacy_only = path.exists(legacy_file) and not path.exists(current_file)
    if legacy_only:
        os.rename(legacy_file, current_file)

    return current_file
|
||
|
||
class NoSuchUserException(Exception):
    """Raised when no stored record matches the given user and user type."""
|
||
|
||
class NoSuchSubscribeException(Exception):
    """Raised when a user record exists but holds no matching subscription."""
|
||
|
||
class Config(metaclass=Singleton):
    """TinyDB-backed store of user subscriptions plus derived in-memory caches.

    Subscriptions live in the ``user_target`` table, one document per
    ``(user, user_type)`` holding a ``subs`` list. ``update_send_cache``
    rebuilds the lookup caches from that table; it is called after every
    mutation made through this class.

    NOTE(review): the ``Singleton`` metaclass comes from ``.utils``;
    presumably every ``Config()`` call yields the same instance -- confirm.
    """

    # Schema version that start_up() compares against the 'version' entry of
    # the 'kv' table.
    migrate_version = 2

    def __init__(self):
        """Open the database file and initialise empty caches."""
        self.db = TinyDB(get_config_path(), encoding='utf-8')
        self.kv_config = self.db.table('kv')
        self.user_target = self.db.table('user_target')
        # target_type -> target -> list of subscribed User
        self.target_user_cache: dict[str, defaultdict[Target, list[User]]] = {}
        # target_type -> target -> user key -> cats / tags of that subscription
        self.target_user_cat_cache = {}
        self.target_user_tag_cache = {}
        # target_type -> list of targets, walked round-robin by get_next_target
        self.target_list = {}
        # target_type -> index of the next target get_next_target will return
        self.next_index: DefaultDict[str, int] = defaultdict(lambda: 0)

    @staticmethod
    def _user_key(user_type, user):
        """Build the '<user_type>-<user>' key used by the cat/tag caches."""
        return '{}-{}'.format(user_type, user)

    def add_subscribe(self, user, user_type, target, target_name, target_type, cats, tags):
        """Append a subscription to the user's record, creating it if needed.

        The send caches are rebuilt afterwards.
        """
        user_query = Query()
        query = (user_query.user == user) & (user_query.user_type == user_type)
        new_sub = {
            'target': target,
            'target_type': target_type,
            'target_name': target_name,
            'cats': cats,
            'tags': tags,
        }
        if (user_data := self.user_target.get(query)):
            # update: extend the existing record's subscription list
            subs: list = user_data.get('subs', [])
            subs.append(new_sub)
            self.user_target.update({"subs": subs}, query)
        else:
            # insert: first subscription of this user
            self.user_target.insert({
                'user': user,
                'user_type': user_type,
                'subs': [new_sub],
            })
        self.update_send_cache()

    def list_subscribe(self, user, user_type):
        """Return the stored ``subs`` list of the user, or ``[]`` if unknown."""
        query = Query()
        if user_sub := self.user_target.get((query.user == user) & (query.user_type == user_type)):
            return user_sub['subs']
        return []

    def get_all_subscribe(self):
        """Return the ``user_target`` table object itself."""
        return self.user_target

    def del_subscribe(self, user, user_type, target, target_type):
        """Remove the first subscription matching ``target``/``target_type``.

        Raises:
            NoSuchUserException: no record exists for ``(user, user_type)``.
            NoSuchSubscribeException: the record has no matching subscription.
        """
        user_query = Query()
        query = (user_query.user == user) & (user_query.user_type == user_type)
        if not (query_res := self.user_target.get(query)):
            raise NoSuchUserException()
        subs = query_res.get('subs', [])
        for idx, sub in enumerate(subs):
            if sub.get('target') == target and sub.get('target_type') == target_type:
                # Safe to pop while iterating: we return right after.
                subs.pop(idx)
                self.user_target.update({'subs': subs}, query)
                self.update_send_cache()
                return
        raise NoSuchSubscribeException()

    def update_send_cache(self):
        """Rebuild every in-memory cache from the ``user_target`` table.

        Subscriptions whose ``target_type`` is not in
        ``supported_target_type`` are ignored.
        """
        res = {target_type: defaultdict(list) for target_type in supported_target_type}
        cat_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
        tag_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type}
        for user in self.user_target.all():
            user_key = self._user_key(user['user_type'], user['user'])
            for sub in user.get('subs', []):
                if sub.get('target_type') not in supported_target_type:
                    continue
                res[sub['target_type']][sub['target']].append(User(user['user'], user['user_type']))
                cat_res[sub['target_type']][sub['target']][user_key] = sub['cats']
                tag_res[sub['target_type']][sub['target']][user_key] = sub['tags']
        self.target_user_cache = res
        self.target_user_cat_cache = cat_res
        self.target_user_tag_cache = tag_res
        for target_type in self.target_user_cache:
            self.target_list[target_type] = list(self.target_user_cache[target_type].keys())

    def get_sub_category(self, target_type, target, user_type, user):
        """Return the cached ``cats`` of one user's subscription to ``target``."""
        return self.target_user_cat_cache[target_type][target][self._user_key(user_type, user)]

    def get_sub_tags(self, target_type, target, user_type, user):
        """Return the cached ``tags`` of one user's subscription to ``target``."""
        return self.target_user_tag_cache[target_type][target][self._user_key(user_type, user)]

    def get_next_target(self, target_type):
        """Return the next target of ``target_type`` in round-robin order.

        Returns ``None`` when there is nothing to poll: the type has no
        targets, or it has no entry in ``target_list`` at all (for example
        before the first ``update_send_cache`` call).
        """
        # FIXME: inserting or deleting a target shifts the queue position
        # (not a big problem, though).
        targets = self.target_list.get(target_type)
        if not targets:
            return None
        # Wrap first so an index left past the end by a shrunken list is valid.
        self.next_index[target_type] %= len(targets)
        res = targets[self.next_index[target_type]]
        self.next_index[target_type] += 1
        return res
|
||
|
||
def start_up():
    """Bring the stored data up to ``Config.migrate_version`` and fill caches.

    When the ``kv`` table has no ``version`` entry, one is inserted with the
    current ``migrate_version``. When the stored version is lower, the
    version-1 migration gives every subscription empty ``cats`` and ``tags``
    lists and the stored version is then set to ``migrate_version``. The send
    caches are rebuilt in every case.
    """
    cfg = Config()
    version_rows = cfg.kv_config.search(Query().name == "version")

    if not version_rows:
        cfg.kv_config.insert({"name": "version", "value": cfg.migrate_version})
    elif version_rows[0].get("value") < cfg.migrate_version:
        stored_version = version_rows[0].get("value")
        if stored_version == 1:
            # Tracks the version reached so far; not read again below.
            stored_version = 2
            for record in cfg.user_target.all():
                migrated_subs = record['subs']
                for entry in migrated_subs:
                    entry['cats'] = []
                    entry['tags'] = []
                cfg.user_target.update({'subs': migrated_subs}, doc_ids=[record.doc_id])
        cfg.kv_config.update({"value": cfg.migrate_version}, Query().name == 'version')

    cfg.update_send_cache()
|
||
|
||
# Register start_up with the nonebot driver's on_startup hook.
nonebot.get_driver().on_startup(start_up)
|