Merge pull request #115 from felinae98/fix-dup-sub

处理「添加重复订阅」异常
This commit is contained in:
felinae98 2022-10-09 19:56:18 +08:00 committed by GitHub
commit 500f8676c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 167 additions and 800 deletions

841
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -7,6 +7,7 @@ from ..config import (
NoSuchUserException, NoSuchUserException,
config, config,
) )
from ..config.db_config import SubscribeDupException
from ..platform import check_sub_target, platform_manager from ..platform import check_sub_target, platform_manager
from ..types import Target as T_Target from ..types import Target as T_Target
from ..types import WeightConfig from ..types import WeightConfig
@ -120,6 +121,7 @@ async def add_group_sub(
cats: list[int], cats: list[int],
tags: list[str], tags: list[str],
): ):
try:
await config.add_subscribe( await config.add_subscribe(
int(group_number), int(group_number),
"group", "group",
@ -130,6 +132,8 @@ async def add_group_sub(
tags, tags,
) )
return {"status": 200, "msg": ""} return {"status": 200, "msg": ""}
except SubscribeDupException:
return {"status": 403, "msg": ""}
async def del_group_sub(group_number: int, platform_name: str, target: str): async def del_group_sub(group_number: int, platform_name: str, target: str):

View File

@ -1,9 +1,9 @@
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, time from datetime import datetime, time
from typing import Any, Awaitable, Callable, Optional from typing import Awaitable, Callable, Optional
from nonebot_plugin_datastore.db import get_engine from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from sqlalchemy.sql.expression import delete, select from sqlalchemy.sql.expression import delete, select
@ -24,6 +24,10 @@ def _get_time():
return cur_time return cur_time
class SubscribeDupException(Exception):
...
class DBConfig: class DBConfig:
def __init__(self): def __init__(self):
self.add_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None self.add_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None
@ -74,7 +78,12 @@ class DBConfig:
target=db_target, target=db_target,
) )
session.add(subscribe) session.add(subscribe)
try:
await session.commit() await session.commit()
except IntegrityError as e:
if len(e.args) > 0 and "UNIQUE constraint failed" in e.args[0]:
raise SubscribeDupException()
raise e
async def list_subscribe(self, user: int, user_type: str) -> list[Subscribe]: async def list_subscribe(self, user: int, user_type: str) -> list[Subscribe]:
async with AsyncSession(get_engine()) as session: async with AsyncSession(get_engine()) as session:

View File

@ -16,6 +16,7 @@ from nonebot.rule import to_me
from nonebot.typing import T_State from nonebot.typing import T_State
from .config import config from .config import config
from .config.db_config import SubscribeDupException
from .platform import Platform, check_sub_target, platform_manager from .platform import Platform, check_sub_target, platform_manager
from .plugin_config import plugin_config from .plugin_config import plugin_config
from .types import Category, Target, User from .types import Category, Target, User
@ -202,6 +203,7 @@ def do_add_sub(add_sub: Type[Matcher]):
async def add_sub_process(event: Event, state: T_State): async def add_sub_process(event: Event, state: T_State):
user = cast(User, state.get("target_user_info")) user = cast(User, state.get("target_user_info"))
assert isinstance(user, User) assert isinstance(user, User)
try:
await config.add_subscribe( await config.add_subscribe(
# state.get("_user_id") or event.group_id, # state.get("_user_id") or event.group_id,
# user_type="group", # user_type="group",
@ -213,6 +215,10 @@ def do_add_sub(add_sub: Type[Matcher]):
cats=state.get("cats", []), cats=state.get("cats", []),
tags=state.get("tags", []), tags=state.get("tags", []),
) )
except SubscribeDupException:
await add_sub.finish(f"添加 {state['name']} 失败: 已存在该订阅")
except Exception as e:
await add_sub.finish(f"添加 {state['name']} 失败: {e}")
await add_sub.finish("添加 {} 成功".format(state["name"])) await add_sub.finish("添加 {} 成功".format(state["name"]))

View File

@ -1,3 +1,4 @@
import pytest
from nonebug.app import App from nonebug.app import App
from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.functions import func from sqlalchemy.sql.functions import func
@ -72,6 +73,33 @@ async def test_add_subscribe(app: App, init_scheduler):
assert conf.tags == ["tag"] assert conf.tags == ["tag"]
async def test_add_dup_sub(init_scheduler):
from nonebot_bison.config.db_config import SubscribeDupException, config
from nonebot_bison.types import Target as TTarget
await config.add_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
with pytest.raises(SubscribeDupException):
await config.add_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
async def test_del_subsribe(init_scheduler): async def test_del_subsribe(init_scheduler):
from nonebot_bison.config.db_config import config from nonebot_bison.config.db_config import config
from nonebot_bison.config.db_model import Subscribe, Target, User from nonebot_bison.config.db_model import Subscribe, Target, User

View File

@ -28,34 +28,21 @@ def ms_list():
@pytest.fixture @pytest.fixture
def pic_hash(): def expected_md():
platform_name = platform.system()
if platform_name == "Windows":
return "58723fdc24b473b6dbd8ec8cbc3b7e46160c83df"
elif platform_name == "Linux":
return "4d540798108762df76de34f7bdbc667dada6b5cb"
elif platform_name == "Darwin":
return "a482bf8317d56e5ddc71437584343ace29ff545c"
else:
raise UnsupportedOperation(f"未支持的平台{platform_name}")
@pytest.fixture
def expect_md():
return "【Zc】每早合约日替攻略<br>![Image](http://i0.hdslb.com/bfs/live/new_room_cover/cf7d4d3b2f336c6dba299644c3af952c5db82612.jpg)\n来源: Bilibili直播 魔法Zc目录<br>详情: https://live.bilibili.com/3044248<br>" return "【Zc】每早合约日替攻略<br>![Image](http://i0.hdslb.com/bfs/live/new_room_cover/cf7d4d3b2f336c6dba299644c3af952c5db82612.jpg)\n来源: Bilibili直播 魔法Zc目录<br>详情: https://live.bilibili.com/3044248<br>"
def test_gene_md(app: App, expect_md, ms_list): def test_gene_md(app: App, expected_md, ms_list):
from nonebot_bison.post.custom_post import CustomPost from nonebot_bison.post.custom_post import CustomPost
cp = CustomPost(message_segments=ms_list) cp = CustomPost(message_segments=ms_list)
cp_md = cp._generate_md() cp_md = cp._generate_md()
assert cp_md == expect_md assert cp_md == expected_md
@respx.mock @respx.mock
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gene_pic(app: App, ms_list, pic_hash): async def test_gene_pic(app: App, ms_list, expected_md):
from nonebot_bison.post.custom_post import CustomPost from nonebot_bison.post.custom_post import CustomPost
pic_router = respx.get( pic_router = respx.get(
@ -69,12 +56,6 @@ async def test_gene_pic(app: App, ms_list, pic_hash):
pic_router.mock(return_value=Response(200, stream=mock_pic)) pic_router.mock(return_value=Response(200, stream=mock_pic))
cp = CustomPost(message_segments=ms_list) cp = CustomPost(message_segments=ms_list)
cp_pic_bytes: list[MessageSegment] = await cp.generate_pic_messages() cp_pic_msg_md: str = cp._generate_md()
pure_b64 = base64.b64decode( assert cp_pic_msg_md == expected_md
cp_pic_bytes[0].data.get("file").replace("base64://", "")
)
sha1obj = hashlib.sha1()
sha1obj.update(pure_b64)
sha1hash = sha1obj.hexdigest()
assert sha1hash == pic_hash