diff --git a/src/plugins/nonebot_bison/admin_page/api.py b/src/plugins/nonebot_bison/admin_page/api.py index 2a1895b..06ec05e 100644 --- a/src/plugins/nonebot_bison/admin_page/api.py +++ b/src/plugins/nonebot_bison/admin_page/api.py @@ -10,6 +10,7 @@ from ..config import ( from ..platform import check_sub_target, platform_manager from ..types import Target as T_Target from ..types import WeightConfig +from ..config.db_config import SubscribeDupException from .jwt import pack_jwt from .token_manager import token_manager @@ -120,16 +121,19 @@ async def add_group_sub( cats: list[int], tags: list[str], ): - await config.add_subscribe( - int(group_number), - "group", - T_Target(target), - target_name, - platform_name, - cats, - tags, - ) - return {"status": 200, "msg": ""} + try: + await config.add_subscribe( + int(group_number), + "group", + T_Target(target), + target_name, + platform_name, + cats, + tags, + ) + return {"status": 200, "msg": ""} + except SubscribeDupException: + return {"status": 403, "msg": ""} async def del_group_sub(group_number: int, platform_name: str, target: str): diff --git a/src/plugins/nonebot_bison/config/db_config.py b/src/plugins/nonebot_bison/config/db_config.py index 90efcc7..d7b3150 100644 --- a/src/plugins/nonebot_bison/config/db_config.py +++ b/src/plugins/nonebot_bison/config/db_config.py @@ -1,9 +1,9 @@ from collections import defaultdict -from dataclasses import dataclass from datetime import datetime, time -from typing import Any, Awaitable, Callable, Optional +from typing import Awaitable, Callable, Optional from nonebot_plugin_datastore.db import get_engine +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.orm import selectinload from sqlalchemy.sql.expression import delete, select @@ -23,6 +23,8 @@ def _get_time(): cur_time = time(hour=dt.hour, minute=dt.minute, second=dt.second) return cur_time +class SubscribeDupException(Exception): + ... class DBConfig: def __init__(self): @@ -74,7 +76,12 @@ class DBConfig: target=db_target, ) session.add(subscribe) - await session.commit() + try: + await session.commit() + except IntegrityError as e: + if len(e.args) > 0 and 'UNIQUE constraint failed' in e.args[0]: + raise SubscribeDupException() + raise e async def list_subscribe(self, user: int, user_type: str) -> list[Subscribe]: async with AsyncSession(get_engine()) as session: diff --git a/src/plugins/nonebot_bison/config_manager.py b/src/plugins/nonebot_bison/config_manager.py index a3b71c2..f81dcaa 100644 --- a/src/plugins/nonebot_bison/config_manager.py +++ b/src/plugins/nonebot_bison/config_manager.py @@ -16,6 +16,7 @@ from nonebot.rule import to_me from nonebot.typing import T_State from .config import config +from .config.db_config import SubscribeDupException from .platform import Platform, check_sub_target, platform_manager from .plugin_config import plugin_config from .types import Category, Target, User @@ -202,18 +203,23 @@ def do_add_sub(add_sub: Type[Matcher]): async def add_sub_process(event: Event, state: T_State): user = cast(User, state.get("target_user_info")) assert isinstance(user, User) - await config.add_subscribe( - # state.get("_user_id") or event.group_id, - # user_type="group", - user=user.user, - user_type=user.user_type, - target=state["id"], - target_name=state["name"], - platform_name=state["platform"], - cats=state.get("cats", []), - tags=state.get("tags", []), - ) - await add_sub.finish("添加 {} 成功".format(state["name"])) + try: + await config.add_subscribe( + # state.get("_user_id") or event.group_id, + # user_type="group", + user=user.user, + user_type=user.user_type, + target=state["id"], + target_name=state["name"], + platform_name=state["platform"], + cats=state.get("cats", []), + tags=state.get("tags", []), + ) + await add_sub.finish("添加 {} 成功".format(state["name"])) + except SubscribeDupException: + await add_sub.finish(f"添加 {state['name']} 失败: 已存在该订阅") + except Exception as e: + await add_sub.finish(f"添加 {state['name']} 失败: {e}") def do_query_sub(query_sub: Type[Matcher]): diff --git a/tests/config/test_config_operation.py b/tests/config/test_config_operation.py index b42dbf7..1e7cecd 100644 --- a/tests/config/test_config_operation.py +++ b/tests/config/test_config_operation.py @@ -1,4 +1,5 @@ from nonebug.app import App +import pytest from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.sql.functions import func from sqlmodel.sql.expression import select @@ -71,6 +72,32 @@ async def test_add_subscribe(app: App, init_scheduler): assert conf.categories == [1] assert conf.tags == ["tag"] +async def test_add_dup_sub(init_scheduler): + + from nonebot_bison.config.db_config import config + from nonebot_bison.types import Target as TTarget + from nonebot_bison.config.db_config import SubscribeDupException + + await config.add_subscribe( + user=123, + user_type="group", + target=TTarget("weibo_id"), + target_name="weibo_name", + platform_name="weibo", + cats=[], + tags=[], + ) + + with pytest.raises(SubscribeDupException): + await config.add_subscribe( + user=123, + user_type="group", + target=TTarget("weibo_id"), + target_name="weibo_name", + platform_name="weibo", + cats=[], + tags=[], + ) async def test_del_subsribe(init_scheduler): from nonebot_bison.config.db_config import config