mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-07 20:33:01 +08:00
commit
500f8676c2
841
poetry.lock
generated
841
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -7,6 +7,7 @@ from ..config import (
|
|||||||
NoSuchUserException,
|
NoSuchUserException,
|
||||||
config,
|
config,
|
||||||
)
|
)
|
||||||
|
from ..config.db_config import SubscribeDupException
|
||||||
from ..platform import check_sub_target, platform_manager
|
from ..platform import check_sub_target, platform_manager
|
||||||
from ..types import Target as T_Target
|
from ..types import Target as T_Target
|
||||||
from ..types import WeightConfig
|
from ..types import WeightConfig
|
||||||
@ -120,6 +121,7 @@ async def add_group_sub(
|
|||||||
cats: list[int],
|
cats: list[int],
|
||||||
tags: list[str],
|
tags: list[str],
|
||||||
):
|
):
|
||||||
|
try:
|
||||||
await config.add_subscribe(
|
await config.add_subscribe(
|
||||||
int(group_number),
|
int(group_number),
|
||||||
"group",
|
"group",
|
||||||
@ -130,6 +132,8 @@ async def add_group_sub(
|
|||||||
tags,
|
tags,
|
||||||
)
|
)
|
||||||
return {"status": 200, "msg": ""}
|
return {"status": 200, "msg": ""}
|
||||||
|
except SubscribeDupException:
|
||||||
|
return {"status": 403, "msg": ""}
|
||||||
|
|
||||||
|
|
||||||
async def del_group_sub(group_number: int, platform_name: str, target: str):
|
async def del_group_sub(group_number: int, platform_name: str, target: str):
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, time
|
from datetime import datetime, time
|
||||||
from typing import Any, Awaitable, Callable, Optional
|
from typing import Awaitable, Callable, Optional
|
||||||
|
|
||||||
from nonebot_plugin_datastore.db import get_engine
|
from nonebot_plugin_datastore.db import get_engine
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
from sqlalchemy.sql.expression import delete, select
|
from sqlalchemy.sql.expression import delete, select
|
||||||
@ -24,6 +24,10 @@ def _get_time():
|
|||||||
return cur_time
|
return cur_time
|
||||||
|
|
||||||
|
|
||||||
|
class SubscribeDupException(Exception):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class DBConfig:
|
class DBConfig:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.add_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None
|
self.add_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None
|
||||||
@ -74,7 +78,12 @@ class DBConfig:
|
|||||||
target=db_target,
|
target=db_target,
|
||||||
)
|
)
|
||||||
session.add(subscribe)
|
session.add(subscribe)
|
||||||
|
try:
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
except IntegrityError as e:
|
||||||
|
if len(e.args) > 0 and "UNIQUE constraint failed" in e.args[0]:
|
||||||
|
raise SubscribeDupException()
|
||||||
|
raise e
|
||||||
|
|
||||||
async def list_subscribe(self, user: int, user_type: str) -> list[Subscribe]:
|
async def list_subscribe(self, user: int, user_type: str) -> list[Subscribe]:
|
||||||
async with AsyncSession(get_engine()) as session:
|
async with AsyncSession(get_engine()) as session:
|
||||||
|
@ -16,6 +16,7 @@ from nonebot.rule import to_me
|
|||||||
from nonebot.typing import T_State
|
from nonebot.typing import T_State
|
||||||
|
|
||||||
from .config import config
|
from .config import config
|
||||||
|
from .config.db_config import SubscribeDupException
|
||||||
from .platform import Platform, check_sub_target, platform_manager
|
from .platform import Platform, check_sub_target, platform_manager
|
||||||
from .plugin_config import plugin_config
|
from .plugin_config import plugin_config
|
||||||
from .types import Category, Target, User
|
from .types import Category, Target, User
|
||||||
@ -202,6 +203,7 @@ def do_add_sub(add_sub: Type[Matcher]):
|
|||||||
async def add_sub_process(event: Event, state: T_State):
|
async def add_sub_process(event: Event, state: T_State):
|
||||||
user = cast(User, state.get("target_user_info"))
|
user = cast(User, state.get("target_user_info"))
|
||||||
assert isinstance(user, User)
|
assert isinstance(user, User)
|
||||||
|
try:
|
||||||
await config.add_subscribe(
|
await config.add_subscribe(
|
||||||
# state.get("_user_id") or event.group_id,
|
# state.get("_user_id") or event.group_id,
|
||||||
# user_type="group",
|
# user_type="group",
|
||||||
@ -213,6 +215,10 @@ def do_add_sub(add_sub: Type[Matcher]):
|
|||||||
cats=state.get("cats", []),
|
cats=state.get("cats", []),
|
||||||
tags=state.get("tags", []),
|
tags=state.get("tags", []),
|
||||||
)
|
)
|
||||||
|
except SubscribeDupException:
|
||||||
|
await add_sub.finish(f"添加 {state['name']} 失败: 已存在该订阅")
|
||||||
|
except Exception as e:
|
||||||
|
await add_sub.finish(f"添加 {state['name']} 失败: {e}")
|
||||||
await add_sub.finish("添加 {} 成功".format(state["name"]))
|
await add_sub.finish("添加 {} 成功".format(state["name"]))
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import pytest
|
||||||
from nonebug.app import App
|
from nonebug.app import App
|
||||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||||
from sqlalchemy.sql.functions import func
|
from sqlalchemy.sql.functions import func
|
||||||
@ -72,6 +73,33 @@ async def test_add_subscribe(app: App, init_scheduler):
|
|||||||
assert conf.tags == ["tag"]
|
assert conf.tags == ["tag"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_dup_sub(init_scheduler):
|
||||||
|
|
||||||
|
from nonebot_bison.config.db_config import SubscribeDupException, config
|
||||||
|
from nonebot_bison.types import Target as TTarget
|
||||||
|
|
||||||
|
await config.add_subscribe(
|
||||||
|
user=123,
|
||||||
|
user_type="group",
|
||||||
|
target=TTarget("weibo_id"),
|
||||||
|
target_name="weibo_name",
|
||||||
|
platform_name="weibo",
|
||||||
|
cats=[],
|
||||||
|
tags=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SubscribeDupException):
|
||||||
|
await config.add_subscribe(
|
||||||
|
user=123,
|
||||||
|
user_type="group",
|
||||||
|
target=TTarget("weibo_id"),
|
||||||
|
target_name="weibo_name",
|
||||||
|
platform_name="weibo",
|
||||||
|
cats=[],
|
||||||
|
tags=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_del_subsribe(init_scheduler):
|
async def test_del_subsribe(init_scheduler):
|
||||||
from nonebot_bison.config.db_config import config
|
from nonebot_bison.config.db_config import config
|
||||||
from nonebot_bison.config.db_model import Subscribe, Target, User
|
from nonebot_bison.config.db_model import Subscribe, Target, User
|
||||||
|
@ -28,34 +28,21 @@ def ms_list():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def pic_hash():
|
def expected_md():
|
||||||
platform_name = platform.system()
|
|
||||||
if platform_name == "Windows":
|
|
||||||
return "58723fdc24b473b6dbd8ec8cbc3b7e46160c83df"
|
|
||||||
elif platform_name == "Linux":
|
|
||||||
return "4d540798108762df76de34f7bdbc667dada6b5cb"
|
|
||||||
elif platform_name == "Darwin":
|
|
||||||
return "a482bf8317d56e5ddc71437584343ace29ff545c"
|
|
||||||
else:
|
|
||||||
raise UnsupportedOperation(f"未支持的平台{platform_name}")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def expect_md():
|
|
||||||
return "【Zc】每早合约日替攻略!<br>\n来源: Bilibili直播 魔法Zc目录<br>详情: https://live.bilibili.com/3044248<br>"
|
return "【Zc】每早合约日替攻略!<br>\n来源: Bilibili直播 魔法Zc目录<br>详情: https://live.bilibili.com/3044248<br>"
|
||||||
|
|
||||||
|
|
||||||
def test_gene_md(app: App, expect_md, ms_list):
|
def test_gene_md(app: App, expected_md, ms_list):
|
||||||
from nonebot_bison.post.custom_post import CustomPost
|
from nonebot_bison.post.custom_post import CustomPost
|
||||||
|
|
||||||
cp = CustomPost(message_segments=ms_list)
|
cp = CustomPost(message_segments=ms_list)
|
||||||
cp_md = cp._generate_md()
|
cp_md = cp._generate_md()
|
||||||
assert cp_md == expect_md
|
assert cp_md == expected_md
|
||||||
|
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_gene_pic(app: App, ms_list, pic_hash):
|
async def test_gene_pic(app: App, ms_list, expected_md):
|
||||||
from nonebot_bison.post.custom_post import CustomPost
|
from nonebot_bison.post.custom_post import CustomPost
|
||||||
|
|
||||||
pic_router = respx.get(
|
pic_router = respx.get(
|
||||||
@ -69,12 +56,6 @@ async def test_gene_pic(app: App, ms_list, pic_hash):
|
|||||||
pic_router.mock(return_value=Response(200, stream=mock_pic))
|
pic_router.mock(return_value=Response(200, stream=mock_pic))
|
||||||
|
|
||||||
cp = CustomPost(message_segments=ms_list)
|
cp = CustomPost(message_segments=ms_list)
|
||||||
cp_pic_bytes: list[MessageSegment] = await cp.generate_pic_messages()
|
cp_pic_msg_md: str = cp._generate_md()
|
||||||
|
|
||||||
pure_b64 = base64.b64decode(
|
assert cp_pic_msg_md == expected_md
|
||||||
cp_pic_bytes[0].data.get("file").replace("base64://", "")
|
|
||||||
)
|
|
||||||
sha1obj = hashlib.sha1()
|
|
||||||
sha1obj.update(pure_b64)
|
|
||||||
sha1hash = sha1obj.hexdigest()
|
|
||||||
assert sha1hash == pic_hash
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user