diff --git a/src/plugins/nonebot_bison/platform/platform.py b/src/plugins/nonebot_bison/platform/platform.py index 8d7d3be..eb2f759 100644 --- a/src/plugins/nonebot_bison/platform/platform.py +++ b/src/plugins/nonebot_bison/platform/platform.py @@ -39,9 +39,11 @@ class RegistryMeta(type): class PlatformMeta(RegistryMeta): categories: dict[Category, str] + store: dict[Target, Any] def __init__(cls, name, bases, namespace, **kwargs): cls.reverse_category = {} + cls.store = {} if hasattr(cls, "categories") and cls.categories: for key, val in cls.categories.items(): cls.reverse_category[val] = key @@ -61,7 +63,6 @@ class Platform(metaclass=PlatformABCMeta, base=True): has_target: bool categories: dict[Category, str] enable_tag: bool - store: dict[Target, Any] platform_name: str parse_target_promot: Optional[str] = None registry: list[Type["Platform"]] @@ -110,7 +111,6 @@ class Platform(metaclass=PlatformABCMeta, base=True): def __init__(self, client: AsyncClient): super().__init__() - self.store = dict() self.client = client class ParseTargetException(Exception): @@ -124,11 +124,13 @@ class Platform(metaclass=PlatformABCMeta, base=True): def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]: "Return Tag list of given RawPost" - def get_stored_data(self, target: Target) -> Any: - return self.store.get(target) + @classmethod + def get_stored_data(cls, target: Target) -> Any: + return cls.store.get(target) - def set_stored_data(self, target: Target, data: Any): - self.store[target] = data + @classmethod + def set_stored_data(cls, target: Target, data: Any): + cls.store[target] = data def tag_separator(self, stored_tags: list[Tag]) -> tuple[list[Tag], list[Tag]]: """返回分离好的正反tag元组""" diff --git a/tests/platforms/test_platform.py b/tests/platforms/test_platform.py index b940160..27f37c1 100644 --- a/tests/platforms/test_platform.py +++ b/tests/platforms/test_platform.py @@ -2,6 +2,7 @@ from time import time from typing import Any, Optional import pytest +from httpx import AsyncClient from nonebug.app import App now = time() @@ -53,8 +54,9 @@ def mock_platform_without_cats_tags(app: App): categories = {} has_target = True + sub_index = 0 + def __init__(self, client): - self.sub_index = 0 super().__init__(client) @classmethod @@ -75,14 +77,15 @@ def mock_platform_without_cats_tags(app: App): target_name="Mock", ) - async def get_sub_list(self, _: "Target"): - if self.sub_index == 0: - self.sub_index += 1 + @classmethod + async def get_sub_list(cls, _: "Target"): + if cls.sub_index == 0: + cls.sub_index += 1 return raw_post_list_1 else: return raw_post_list_2 - return MockPlatform(None) + return MockPlatform @pytest.fixture @@ -112,8 +115,9 @@ def mock_platform(app: App): Category(2): "视频", } + sub_index = 0 + def __init__(self, client): - self.sub_index = 0 super().__init__(client) @staticmethod @@ -140,14 +144,15 @@ def mock_platform(app: App): target_name="Mock", ) - async def get_sub_list(self, _: "Target"): - if self.sub_index == 0: - self.sub_index += 1 + @classmethod + async def get_sub_list(cls, _: "Target"): + if cls.sub_index == 0: + cls.sub_index += 1 return raw_post_list_1 else: return raw_post_list_2 - return MockPlatform(None) + return MockPlatform @pytest.fixture @@ -180,8 +185,9 @@ def mock_platform_no_target(app: App, mock_scheduler_conf): has_target = False categories = {Category(1): "转发", Category(2): "视频", Category(3): "不支持"} + sub_index = 0 + def __init__(self, client): - self.sub_index = 0 super().__init__(client) @staticmethod @@ -210,9 +216,10 @@ def mock_platform_no_target(app: App, mock_scheduler_conf): target_name="Mock", ) - async def get_sub_list(self, _: "Target"): - if self.sub_index == 0: - self.sub_index += 1 + @classmethod + async def get_sub_list(cls, _: "Target"): + if cls.sub_index == 0: + cls.sub_index += 1 return raw_post_list_1 else: return raw_post_list_2 @@ -241,8 +248,9 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf): Category(5): "leixing5", } + sub_index = 0 + def __init__(self, client): - self.sub_index = 0 super().__init__(client) @classmethod @@ -269,7 +277,8 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf): target_name="Mock", ) - async def get_sub_list(self, _: "Target"): + @classmethod + async def get_sub_list(cls, _: "Target"): list_1 = [ {"id": 5, "text": "p5", "date": now, "tags": ["tag1"], "category": 4} ] @@ -278,8 +287,8 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf): {"id": 6, "text": "p6", "date": now, "tags": ["tag1"], "category": 4}, {"id": 7, "text": "p7", "date": now, "tags": ["tag2"], "category": 5}, ] - if self.sub_index == 0: - self.sub_index += 1 + if cls.sub_index == 0: + cls.sub_index += 1 return list_1 else: return list_2 @@ -308,16 +317,18 @@ def mock_status_change(app: App): Category(2): "视频", } + sub_index = 0 + def __init__(self, client): - self.sub_index = 0 super().__init__(client) - async def get_status(self, _: "Target"): - if self.sub_index == 0: - self.sub_index += 1 + @classmethod + async def get_status(cls, _: "Target"): + if cls.sub_index == 0: + cls.sub_index += 1 return {"s": False} - elif self.sub_index == 1: - self.sub_index += 1 + elif cls.sub_index == 1: + cls.sub_index += 1 return {"s": True} else: return {"s": False} @@ -335,18 +346,18 @@ def mock_status_change(app: App): def get_category(self, raw_post): return raw_post["cat"] - return MockPlatform(None) + return MockPlatform @pytest.mark.asyncio async def test_new_message_target_without_cats_tags( mock_platform_without_cats_tags, user_info_factory ): - res1 = await mock_platform_without_cats_tags.fetch_new_post( + res1 = await mock_platform_without_cats_tags(AsyncClient()).fetch_new_post( "dummy", [user_info_factory([1, 2], [])] ) assert len(res1) == 0 - res2 = await mock_platform_without_cats_tags.fetch_new_post( + res2 = await mock_platform_without_cats_tags(AsyncClient()).fetch_new_post( "dummy", [ user_info_factory([], []), @@ -361,9 +372,11 @@ async def test_new_message_target_without_cats_tags( @pytest.mark.asyncio async def test_new_message_target(mock_platform, user_info_factory): - res1 = await mock_platform.fetch_new_post("dummy", [user_info_factory([1, 2], [])]) + res1 = await mock_platform(AsyncClient()).fetch_new_post( + "dummy", [user_info_factory([1, 2], [])] + ) assert len(res1) == 0 - res2 = await mock_platform.fetch_new_post( + res2 = await mock_platform(AsyncClient()).fetch_new_post( "dummy", [ user_info_factory([1, 2], []), @@ -388,12 +401,11 @@ async def test_new_message_target(mock_platform, user_info_factory): @pytest.mark.asyncio async def test_new_message_no_target(mock_platform_no_target, user_info_factory): - mock_platform_no_target = mock_platform_no_target(None) - res1 = await mock_platform_no_target.fetch_new_post( + res1 = await mock_platform_no_target(AsyncClient()).fetch_new_post( "dummy", [user_info_factory([1, 2], [])] ) assert len(res1) == 0 - res2 = await mock_platform_no_target.fetch_new_post( + res2 = await mock_platform_no_target(AsyncClient()).fetch_new_post( "dummy", [ user_info_factory([1, 2], []), @@ -414,7 +426,7 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory) assert "p2" in id_set_1 and "p3" in id_set_1 assert "p2" in id_set_2 assert "p2" in id_set_3 - res3 = await mock_platform_no_target.fetch_new_post( + res3 = await mock_platform_no_target(AsyncClient()).fetch_new_post( "dummy", [user_info_factory([1, 2], [])] ) assert len(res3) == 0 @@ -422,18 +434,18 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory) @pytest.mark.asyncio async def test_status_change(mock_status_change, user_info_factory): - res1 = await mock_status_change.fetch_new_post( + res1 = await mock_status_change(AsyncClient()).fetch_new_post( "dummy", [user_info_factory([1, 2], [])] ) assert len(res1) == 0 - res2 = await mock_status_change.fetch_new_post( + res2 = await mock_status_change(AsyncClient()).fetch_new_post( "dummy", [user_info_factory([1, 2], [])] ) assert len(res2) == 1 posts = res2[0][1] assert len(posts) == 1 assert posts[0].text == "on" - res3 = await mock_status_change.fetch_new_post( + res3 = await mock_status_change(AsyncClient()).fetch_new_post( "dummy", [ user_info_factory([1, 2], []), @@ -444,7 +456,7 @@ async def test_status_change(mock_status_change, user_info_factory): assert len(res3[0][1]) == 1 assert res3[0][1][0].text == "off" assert len(res3[1][1]) == 0 - res4 = await mock_status_change.fetch_new_post( + res4 = await mock_status_change(AsyncClient()).fetch_new_post( "dummy", [user_info_factory([1, 2], [])] ) assert len(res4) == 0