# mirror of https://github.com/suyiiyii/nonebot-bison.git
# synced 2025-06-07 12:23:00 +08:00
# 164 lines, 5.3 KiB, Python
import os
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
from typing import Union
|
||
|
||
import socketio
|
||
from fastapi.staticfiles import StaticFiles
|
||
from nonebot import get_driver, on_command
|
||
from nonebot.adapters.onebot.v11 import Bot
|
||
from nonebot.adapters.onebot.v11.event import GroupMessageEvent, PrivateMessageEvent
|
||
from nonebot.drivers.fastapi import Driver
|
||
from nonebot.log import logger
|
||
from nonebot.rule import to_me
|
||
from nonebot.typing import T_State
|
||
|
||
from ..plugin_config import plugin_config
|
||
from .api import (
|
||
add_group_sub,
|
||
auth,
|
||
del_group_sub,
|
||
get_global_conf,
|
||
get_subs_info,
|
||
get_target_name,
|
||
test,
|
||
update_group_sub,
|
||
)
|
||
from .jwt import load_jwt
|
||
from .token_manager import token_manager as tm
|
||
|
||
# Prefix shared by every route of the web frontend (keeps its trailing slash).
URL_BASE = "/bison/"

# API endpoints, each built on top of URL_BASE.
GLOBAL_CONF_URL = URL_BASE + "api/global_conf"
AUTH_URL = URL_BASE + "api/auth"
SUBSCRIBE_URL = URL_BASE + "api/subs"
GET_TARGET_NAME_URL = URL_BASE + "api/target_name"
TEST_URL = URL_BASE + "test"

# Absolute path of the "dist" directory sitting next to this module.
STATIC_PATH = Path(__file__).parent.joinpath("dist").resolve()
|
||
|
||
# Async Socket.IO server in ASGI mode, accepting connections from any origin.
sio = socketio.AsyncServer(
    cors_allowed_origins="*",
    async_mode="asgi",
)
# ASGI wrapper around the server, using "socket" as the Socket.IO path.
socket_app = socketio.ASGIApp(sio, socketio_path="socket")
|
||
|
||
|
||
class SinglePageApplication(StaticFiles):
    """Static file application that serves the index page for unknown paths.

    Any path the parent class cannot resolve is answered with the configured
    index file instead.
    """

    def __init__(self, directory: os.PathLike, index="index.html"):
        """Remember the fallback file name and initialise the static file app.

        Args:
            directory: Directory holding the files to serve.
            index: File name looked up when the requested path is not found.
        """
        self.index = index
        super().__init__(
            directory=directory,
            packages=None,
            html=True,
            check_dir=True,
        )

    def lookup_path(self, path: str) -> tuple[str, Union[os.stat_result, None]]:
        """Resolve *path*, falling back to the index file when it is missing.

        Returns:
            The ``(full_path, stat_result)`` pair from the parent lookup, or
            the parent's lookup result for ``self.index`` when the first
            lookup produced no stat result.
        """
        found_path, found_stat = super().lookup_path(path)
        if found_stat is not None:
            return (found_path, found_stat)
        # Nothing on disk for this path: serve the index page instead.
        return super().lookup_path(self.index)
|
||
|
||
|
||
def register_router_fastapi(driver: Driver, socketio):
    """Register the bison API routes and the static frontend on the FastAPI app.

    Args:
        driver: FastAPI driver whose ``server_app`` receives the routes.
        socketio: Kept for interface compatibility; this function does not
            reference it.
    """
    from fastapi import HTTPException, status
    from fastapi.param_functions import Depends
    from fastapi.security import OAuth2PasswordBearer

    oauth_scheme = OAuth2PasswordBearer(tokenUrl="token")

    async def get_jwt_obj(token: str = Depends(oauth_scheme)):
        """Decode the bearer token; respond 401 when ``load_jwt`` returns a falsy value."""
        obj = load_jwt(token)
        if not obj:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
        return obj

    async def check_group_permission(
        groupNumber: str, token_obj: dict = Depends(get_jwt_obj)
    ):
        """Allow the request only if ``groupNumber`` is one of the token's groups.

        Raises:
            HTTPException: 403 when the group is not listed in the token, or
                when ``groupNumber`` is not a valid integer.
        """
        # groupNumber is a client-supplied query string. Convert it once, and
        # answer a malformed value with the same 403 as an unknown group
        # instead of letting ValueError surface as a 500.
        try:
            requested_group = int(groupNumber)
        except ValueError:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) from None
        if not any(group["id"] == requested_group for group in token_obj["groups"]):
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

    @dataclass
    class AddSubscribeReq:
        """Request body shared by the add and update subscription endpoints."""

        platformName: str
        target: str
        targetName: str
        cats: list[int]
        tags: list[str]

    app = driver.server_app
    static_path = STATIC_PATH

    # Endpoints implemented in .api that are registered as-is.
    app.get(TEST_URL)(test)
    app.get(GLOBAL_CONF_URL)(get_global_conf)
    app.get(AUTH_URL)(auth)

    @app.get(SUBSCRIBE_URL)
    async def subs(jwt_obj: dict = Depends(get_jwt_obj)):
        return await get_subs_info(jwt_obj)

    @app.get(GET_TARGET_NAME_URL)
    async def _get_target_name(
        platformName: str, target: str, jwt_obj: dict = Depends(get_jwt_obj)
    ):
        return await get_target_name(platformName, target, jwt_obj)

    # The mutating endpoints below all require membership of the target group.
    @app.post(SUBSCRIBE_URL, dependencies=[Depends(check_group_permission)])
    async def _add_group_subs(groupNumber: str, req: AddSubscribeReq):
        return await add_group_sub(
            group_number=groupNumber,
            platform_name=req.platformName,
            target=req.target,
            target_name=req.targetName,
            cats=req.cats,
            tags=req.tags,
        )

    @app.patch(SUBSCRIBE_URL, dependencies=[Depends(check_group_permission)])
    async def _update_group_subs(groupNumber: str, req: AddSubscribeReq):
        return await update_group_sub(
            group_number=groupNumber,
            platform_name=req.platformName,
            target=req.target,
            target_name=req.targetName,
            cats=req.cats,
            tags=req.tags,
        )

    @app.delete(SUBSCRIBE_URL, dependencies=[Depends(check_group_permission)])
    async def _del_group_subs(groupNumber: str, target: str, platformName: str):
        return await del_group_sub(groupNumber, platformName, target)

    # Mounted after the API routes are registered, under the same URL prefix.
    app.mount(URL_BASE, SinglePageApplication(directory=static_path), name="bison")
|
||
|
||
|
||
def init():
    """Register the web routes on a FastAPI driver and log the frontend URL.

    For any driver whose type does not contain "fastapi", a warning is logged
    and nothing is registered.
    """
    driver = get_driver()
    if "fastapi" not in driver.type:
        logger.warning(f"Driver {driver.type} not supported")
        return

    assert isinstance(driver, Driver)
    register_router_fastapi(driver, socket_app)

    host = str(driver.config.host)
    port = driver.config.port
    # Show "localhost" in the logged URL for these two addresses.
    if host == "0.0.0.0" or host == "127.0.0.1":
        host = "localhost"
    message = (
        f"Nonebot test frontend will be running at: "
        f"<b><u>http://{host}:{port}{URL_BASE}</u></b>"
    )
    logger.opt(colors=True).info(message)
|
||
|
||
|
||
# The admin command is registered only when the compiled frontend is present.
if not STATIC_PATH.joinpath("index.html").exists():
    logger.warning(
        "Frontend file not found, please compile it or use docker or pypi version"
    )
else:
    init()

    get_token = on_command("后台管理", rule=to_me(), priority=5)

    @get_token.handle()
    async def send_token(bot: "Bot", event: PrivateMessageEvent, state: T_State):
        """Reply with an admin-page URL containing a token for the sender."""
        user_key = (event.get_user_id(), event.sender.nickname)
        token = tm.get_user_token(user_key)
        await get_token.finish(f"请访问: {plugin_config.bison_outer_url}auth/{token}")

    # Help metadata attached to the matcher.
    get_token.__help__name__ = "获取后台管理地址"
    get_token.__help__info__ = "获取管理bot后台的地址,该地址会" "在一段时间过后过期,请不要泄漏该地址"
|