mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2026-06-23 22:16:53 +08:00
format code
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import socketio
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from nonebot import get_driver, on_command
|
||||
from nonebot.adapters.cqhttp.bot import Bot
|
||||
@@ -11,7 +12,6 @@ from nonebot.drivers.fastapi import Driver
|
||||
from nonebot.log import logger
|
||||
from nonebot.rule import to_me
|
||||
from nonebot.typing import T_State
|
||||
import socketio
|
||||
|
||||
from ..plugin_config import plugin_config
|
||||
from .api import (
|
||||
@@ -27,21 +27,21 @@ from .api import (
|
||||
from .jwt import load_jwt
|
||||
from .token_manager import token_manager as tm
|
||||
|
||||
URL_BASE = '/bison/'
|
||||
GLOBAL_CONF_URL = f'{URL_BASE}api/global_conf'
|
||||
AUTH_URL = f'{URL_BASE}api/auth'
|
||||
SUBSCRIBE_URL = f'{URL_BASE}api/subs'
|
||||
GET_TARGET_NAME_URL = f'{URL_BASE}api/target_name'
|
||||
TEST_URL = f'{URL_BASE}test'
|
||||
URL_BASE = "/bison/"
|
||||
GLOBAL_CONF_URL = f"{URL_BASE}api/global_conf"
|
||||
AUTH_URL = f"{URL_BASE}api/auth"
|
||||
SUBSCRIBE_URL = f"{URL_BASE}api/subs"
|
||||
GET_TARGET_NAME_URL = f"{URL_BASE}api/target_name"
|
||||
TEST_URL = f"{URL_BASE}test"
|
||||
|
||||
STATIC_PATH = (Path(__file__).parent / "dist").resolve()
|
||||
|
||||
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
socket_app = socketio.ASGIApp(sio, socketio_path="socket")
|
||||
|
||||
class SinglePageApplication(StaticFiles):
|
||||
|
||||
def __init__(self, directory: os.PathLike, index='index.html'):
|
||||
class SinglePageApplication(StaticFiles):
|
||||
def __init__(self, directory: os.PathLike, index="index.html"):
|
||||
self.index = index
|
||||
super().__init__(directory=directory, packages=None, html=True, check_dir=True)
|
||||
|
||||
@@ -51,12 +51,13 @@ class SinglePageApplication(StaticFiles):
|
||||
return await super().lookup_path(self.index)
|
||||
return (full_path, stat_res)
|
||||
|
||||
def register_router_fastapi(driver: Driver, socketio):
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.param_functions import Depends
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
oath_scheme = OAuth2PasswordBearer(tokenUrl='token')
|
||||
def register_router_fastapi(driver: Driver, socketio):
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.param_functions import Depends
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
oath_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
async def get_jwt_obj(token: str = Depends(oath_scheme)):
|
||||
obj = load_jwt(token)
|
||||
@@ -64,10 +65,12 @@ def register_router_fastapi(driver: Driver, socketio):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
return obj
|
||||
|
||||
async def check_group_permission(groupNumber: str, token_obj: dict = Depends(get_jwt_obj)):
|
||||
groups = token_obj['groups']
|
||||
async def check_group_permission(
|
||||
groupNumber: str, token_obj: dict = Depends(get_jwt_obj)
|
||||
):
|
||||
groups = token_obj["groups"]
|
||||
for group in groups:
|
||||
if int(groupNumber) == group['id']:
|
||||
if int(groupNumber) == group["id"]:
|
||||
return
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
|
||||
@@ -88,17 +91,35 @@ def register_router_fastapi(driver: Driver, socketio):
|
||||
@app.get(SUBSCRIBE_URL)
|
||||
async def subs(jwt_obj: dict = Depends(get_jwt_obj)):
|
||||
return await get_subs_info(jwt_obj)
|
||||
|
||||
@app.get(GET_TARGET_NAME_URL)
|
||||
async def _get_target_name(platformName: str, target: str, jwt_obj: dict = Depends(get_jwt_obj)):
|
||||
async def _get_target_name(
|
||||
platformName: str, target: str, jwt_obj: dict = Depends(get_jwt_obj)
|
||||
):
|
||||
return await get_target_name(platformName, target, jwt_obj)
|
||||
|
||||
@app.post(SUBSCRIBE_URL, dependencies=[Depends(check_group_permission)])
|
||||
async def _add_group_subs(groupNumber: str, req: AddSubscribeReq):
|
||||
return await add_group_sub(group_number=groupNumber, platform_name=req.platformName,
|
||||
target=req.target, target_name=req.targetName, cats=req.cats, tags=req.tags)
|
||||
return await add_group_sub(
|
||||
group_number=groupNumber,
|
||||
platform_name=req.platformName,
|
||||
target=req.target,
|
||||
target_name=req.targetName,
|
||||
cats=req.cats,
|
||||
tags=req.tags,
|
||||
)
|
||||
|
||||
@app.patch(SUBSCRIBE_URL, dependencies=[Depends(check_group_permission)])
|
||||
async def _update_group_subs(groupNumber: str, req: AddSubscribeReq):
|
||||
return await update_group_sub(group_number=groupNumber, platform_name=req.platformName,
|
||||
target=req.target, target_name=req.targetName, cats=req.cats, tags=req.tags)
|
||||
return await update_group_sub(
|
||||
group_number=groupNumber,
|
||||
platform_name=req.platformName,
|
||||
target=req.target,
|
||||
target_name=req.targetName,
|
||||
cats=req.cats,
|
||||
tags=req.tags,
|
||||
)
|
||||
|
||||
@app.delete(SUBSCRIBE_URL, dependencies=[Depends(check_group_permission)])
|
||||
async def _del_group_subs(groupNumber: str, target: str, platformName: str):
|
||||
return await del_group_sub(groupNumber, platformName, target)
|
||||
@@ -108,8 +129,8 @@ def register_router_fastapi(driver: Driver, socketio):
|
||||
|
||||
def init():
|
||||
driver = get_driver()
|
||||
if driver.type == 'fastapi':
|
||||
assert(isinstance(driver, Driver))
|
||||
if driver.type == "fastapi":
|
||||
assert isinstance(driver, Driver)
|
||||
register_router_fastapi(driver, socket_app)
|
||||
else:
|
||||
logger.warning(f"Driver {driver.type} not supported")
|
||||
@@ -118,19 +139,25 @@ def init():
|
||||
port = driver.config.port
|
||||
if host in ["0.0.0.0", "127.0.0.1"]:
|
||||
host = "localhost"
|
||||
logger.opt(colors=True).info(f"Nonebot test frontend will be running at: "
|
||||
f"<b><u>http://{host}:{port}{URL_BASE}</u></b>")
|
||||
logger.opt(colors=True).info(
|
||||
f"Nonebot test frontend will be running at: "
|
||||
f"<b><u>http://{host}:{port}{URL_BASE}</u></b>"
|
||||
)
|
||||
|
||||
if (STATIC_PATH / 'index.html').exists():
|
||||
|
||||
if (STATIC_PATH / "index.html").exists():
|
||||
init()
|
||||
|
||||
get_token = on_command('后台管理', rule=to_me(), priority=5)
|
||||
get_token = on_command("后台管理", rule=to_me(), priority=5)
|
||||
|
||||
@get_token.handle()
|
||||
async def send_token(bot: "Bot", event: PrivateMessageEvent, state: T_State):
|
||||
token = tm.get_user_token((event.get_user_id(), event.sender.nickname))
|
||||
await get_token.finish(f'请访问: {plugin_config.bison_outer_url}auth/{token}')
|
||||
get_token.__help__name__ = '获取后台管理地址'
|
||||
get_token.__help__info__ = ('获取管理bot后台的地址,该地址会'
|
||||
'在一段时间过后过期,请不要泄漏该地址')
|
||||
await get_token.finish(f"请访问: {plugin_config.bison_outer_url}auth/{token}")
|
||||
|
||||
get_token.__help__name__ = "获取后台管理地址"
|
||||
get_token.__help__info__ = "获取管理bot后台的地址,该地址会" "在一段时间过后过期,请不要泄漏该地址"
|
||||
else:
|
||||
logger.warning("Frontend file not found, please compile it or use docker or pypi version")
|
||||
logger.warning(
|
||||
"Frontend file not found, please compile it or use docker or pypi version"
|
||||
)
|
||||
|
||||
@@ -6,110 +6,141 @@ from ..platform import check_sub_target, platform_manager
|
||||
from .jwt import pack_jwt
|
||||
from .token_manager import token_manager
|
||||
|
||||
|
||||
async def test():
|
||||
return {"status": 200, "text": "test"}
|
||||
|
||||
|
||||
async def get_global_conf():
|
||||
res = {}
|
||||
for platform_name, platform in platform_manager.items():
|
||||
res[platform_name] = {
|
||||
'platformName': platform_name,
|
||||
'categories': platform.categories,
|
||||
'enabledTag': platform.enable_tag,
|
||||
'name': platform.name,
|
||||
'hasTarget': getattr(platform, 'has_target')
|
||||
}
|
||||
return { 'platformConf': res }
|
||||
"platformName": platform_name,
|
||||
"categories": platform.categories,
|
||||
"enabledTag": platform.enable_tag,
|
||||
"name": platform.name,
|
||||
"hasTarget": getattr(platform, "has_target"),
|
||||
}
|
||||
return {"platformConf": res}
|
||||
|
||||
|
||||
async def get_admin_groups(qq: int):
|
||||
bot = nonebot.get_bot()
|
||||
groups = await bot.call_api('get_group_list')
|
||||
groups = await bot.call_api("get_group_list")
|
||||
res = []
|
||||
for group in groups:
|
||||
group_id = group['group_id']
|
||||
users = await bot.call_api('get_group_member_list', group_id=group_id)
|
||||
group_id = group["group_id"]
|
||||
users = await bot.call_api("get_group_member_list", group_id=group_id)
|
||||
for user in users:
|
||||
if user['user_id'] == qq and user['role'] in ('owner', 'admin'):
|
||||
res.append({'id': group_id, 'name': group['group_name']})
|
||||
if user["user_id"] == qq and user["role"] in ("owner", "admin"):
|
||||
res.append({"id": group_id, "name": group["group_name"]})
|
||||
return res
|
||||
|
||||
|
||||
async def auth(token: str):
|
||||
if qq_tuple := token_manager.get_user(token):
|
||||
qq, nickname = qq_tuple
|
||||
bot = nonebot.get_bot()
|
||||
assert(isinstance(bot, Bot))
|
||||
groups = await bot.call_api('get_group_list')
|
||||
assert isinstance(bot, Bot)
|
||||
groups = await bot.call_api("get_group_list")
|
||||
if str(qq) in nonebot.get_driver().config.superusers:
|
||||
jwt_obj = {
|
||||
'id': str(qq),
|
||||
'groups': list(map(
|
||||
lambda info: {'id': info['group_id'], 'name': info['group_name']},
|
||||
groups)),
|
||||
}
|
||||
"id": str(qq),
|
||||
"groups": list(
|
||||
map(
|
||||
lambda info: {
|
||||
"id": info["group_id"],
|
||||
"name": info["group_name"],
|
||||
},
|
||||
groups,
|
||||
)
|
||||
),
|
||||
}
|
||||
ret_obj = {
|
||||
'type': 'admin',
|
||||
'name': nickname,
|
||||
'id': str(qq),
|
||||
'token': pack_jwt(jwt_obj)
|
||||
}
|
||||
return { 'status': 200, **ret_obj }
|
||||
"type": "admin",
|
||||
"name": nickname,
|
||||
"id": str(qq),
|
||||
"token": pack_jwt(jwt_obj),
|
||||
}
|
||||
return {"status": 200, **ret_obj}
|
||||
if admin_groups := await get_admin_groups(int(qq)):
|
||||
jwt_obj = {
|
||||
'id': str(qq),
|
||||
'groups': admin_groups
|
||||
}
|
||||
jwt_obj = {"id": str(qq), "groups": admin_groups}
|
||||
ret_obj = {
|
||||
'type': 'user',
|
||||
'name': nickname,
|
||||
'id': str(qq),
|
||||
'token': pack_jwt(jwt_obj)
|
||||
}
|
||||
return { 'status': 200, **ret_obj }
|
||||
"type": "user",
|
||||
"name": nickname,
|
||||
"id": str(qq),
|
||||
"token": pack_jwt(jwt_obj),
|
||||
}
|
||||
return {"status": 200, **ret_obj}
|
||||
else:
|
||||
return { 'status': 400, 'type': '', 'name': '', 'id': '', 'token': '' }
|
||||
return {"status": 400, "type": "", "name": "", "id": "", "token": ""}
|
||||
else:
|
||||
return { 'status': 400, 'type': '', 'name': '', 'id': '', 'token': '' }
|
||||
return {"status": 400, "type": "", "name": "", "id": "", "token": ""}
|
||||
|
||||
|
||||
async def get_subs_info(jwt_obj: dict):
|
||||
groups = jwt_obj['groups']
|
||||
groups = jwt_obj["groups"]
|
||||
res = {}
|
||||
for group in groups:
|
||||
group_id = group['id']
|
||||
group_id = group["id"]
|
||||
config = Config()
|
||||
subs = list(map(lambda sub: {
|
||||
'platformName': sub['target_type'], 'target': sub['target'], 'targetName': sub['target_name'], 'cats': sub['cats'], 'tags': sub['tags']
|
||||
}, config.list_subscribe(group_id, 'group')))
|
||||
res[group_id] = {
|
||||
'name': group['name'],
|
||||
'subscribes': subs
|
||||
}
|
||||
subs = list(
|
||||
map(
|
||||
lambda sub: {
|
||||
"platformName": sub["target_type"],
|
||||
"target": sub["target"],
|
||||
"targetName": sub["target_name"],
|
||||
"cats": sub["cats"],
|
||||
"tags": sub["tags"],
|
||||
},
|
||||
config.list_subscribe(group_id, "group"),
|
||||
)
|
||||
)
|
||||
res[group_id] = {"name": group["name"], "subscribes": subs}
|
||||
return res
|
||||
|
||||
async def get_target_name(platform_name: str, target: str, jwt_obj: dict):
|
||||
return {'targetName': await check_sub_target(platform_name, target)}
|
||||
|
||||
async def add_group_sub(group_number: str, platform_name: str, target: str,
|
||||
target_name: str, cats: list[int], tags: list[str]):
|
||||
async def get_target_name(platform_name: str, target: str, jwt_obj: dict):
|
||||
return {"targetName": await check_sub_target(platform_name, target)}
|
||||
|
||||
|
||||
async def add_group_sub(
|
||||
group_number: str,
|
||||
platform_name: str,
|
||||
target: str,
|
||||
target_name: str,
|
||||
cats: list[int],
|
||||
tags: list[str],
|
||||
):
|
||||
config = Config()
|
||||
config.add_subscribe(int(group_number), 'group', target, target_name, platform_name, cats, tags)
|
||||
return { 'status': 200, 'msg': '' }
|
||||
config.add_subscribe(
|
||||
int(group_number), "group", target, target_name, platform_name, cats, tags
|
||||
)
|
||||
return {"status": 200, "msg": ""}
|
||||
|
||||
|
||||
async def del_group_sub(group_number: str, platform_name: str, target: str):
|
||||
config = Config()
|
||||
try:
|
||||
config.del_subscribe(int(group_number), 'group', target, platform_name)
|
||||
config.del_subscribe(int(group_number), "group", target, platform_name)
|
||||
except (NoSuchUserException, NoSuchSubscribeException):
|
||||
return { 'status': 400, 'msg': '删除错误' }
|
||||
return { 'status': 200, 'msg': '' }
|
||||
return {"status": 400, "msg": "删除错误"}
|
||||
return {"status": 200, "msg": ""}
|
||||
|
||||
|
||||
async def update_group_sub(group_number: str, platform_name: str, target: str,
|
||||
target_name: str, cats: list[int], tags: list[str]):
|
||||
async def update_group_sub(
|
||||
group_number: str,
|
||||
platform_name: str,
|
||||
target: str,
|
||||
target_name: str,
|
||||
cats: list[int],
|
||||
tags: list[str],
|
||||
):
|
||||
config = Config()
|
||||
try:
|
||||
config.update_subscribe(int(group_number), 'group',
|
||||
target, target_name, platform_name, cats, tags)
|
||||
config.update_subscribe(
|
||||
int(group_number), "group", target, target_name, platform_name, cats, tags
|
||||
)
|
||||
except (NoSuchUserException, NoSuchSubscribeException):
|
||||
return { 'status': 400, 'msg': '更新错误' }
|
||||
return { 'status': 200, 'msg': '' }
|
||||
|
||||
return {"status": 400, "msg": "更新错误"}
|
||||
return {"status": 200, "msg": ""}
|
||||
|
||||
@@ -1,20 +1,23 @@
|
||||
import datetime
|
||||
import random
|
||||
import string
|
||||
from typing import Optional
|
||||
import jwt
|
||||
import datetime
|
||||
|
||||
_key = ''.join(random.SystemRandom().choice(string.ascii_letters) for _ in range(16))
|
||||
import jwt
|
||||
|
||||
_key = "".join(random.SystemRandom().choice(string.ascii_letters) for _ in range(16))
|
||||
|
||||
|
||||
def pack_jwt(obj: dict) -> str:
|
||||
return jwt.encode(
|
||||
{'exp': datetime.datetime.utcnow() + datetime.timedelta(hours=1), **obj},
|
||||
_key, algorithm='HS256'
|
||||
)
|
||||
{"exp": datetime.datetime.utcnow() + datetime.timedelta(hours=1), **obj},
|
||||
_key,
|
||||
algorithm="HS256",
|
||||
)
|
||||
|
||||
|
||||
def load_jwt(token: str) -> Optional[dict]:
|
||||
try:
|
||||
return jwt.decode(token, _key, algorithms=['HS256'])
|
||||
return jwt.decode(token, _key, algorithms=["HS256"])
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,24 +1,26 @@
|
||||
from typing import Optional
|
||||
from expiringdict import ExpiringDict
|
||||
import random
|
||||
import string
|
||||
import string
|
||||
from typing import Optional
|
||||
|
||||
from expiringdict import ExpiringDict
|
||||
|
||||
|
||||
class TokenManager:
|
||||
|
||||
def __init__(self):
|
||||
self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60*10)
|
||||
self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60 * 10)
|
||||
|
||||
def get_user(self, token: str) -> Optional[tuple]:
|
||||
res = self.token_manager.get(token)
|
||||
assert(res is None or isinstance(res, tuple))
|
||||
assert res is None or isinstance(res, tuple)
|
||||
return res
|
||||
|
||||
def save_user(self, token: str, qq: tuple) -> None:
|
||||
self.token_manager[token] = qq
|
||||
|
||||
def get_user_token(self, qq: tuple) -> str:
|
||||
token = ''.join(random.choices(string.ascii_letters + string.digits, k=16))
|
||||
token = "".join(random.choices(string.ascii_letters + string.digits, k=16))
|
||||
self.save_user(token, qq)
|
||||
return token
|
||||
|
||||
|
||||
token_manager = TokenManager()
|
||||
|
||||
Reference in New Issue
Block a user