Merge pull request #129 from felinae98/fix-new-message

fix platform store init error
This commit is contained in:
felinae98 2022-10-17 20:20:23 +08:00 committed by GitHub
commit fdaee5f227
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 43 deletions

View File

@ -39,9 +39,11 @@ class RegistryMeta(type):
class PlatformMeta(RegistryMeta):
categories: dict[Category, str]
store: dict[Target, Any]
def __init__(cls, name, bases, namespace, **kwargs):
cls.reverse_category = {}
cls.store = {}
if hasattr(cls, "categories") and cls.categories:
for key, val in cls.categories.items():
cls.reverse_category[val] = key
@ -61,7 +63,6 @@ class Platform(metaclass=PlatformABCMeta, base=True):
has_target: bool
categories: dict[Category, str]
enable_tag: bool
store: dict[Target, Any]
platform_name: str
parse_target_promot: Optional[str] = None
registry: list[Type["Platform"]]
@ -110,7 +111,6 @@ class Platform(metaclass=PlatformABCMeta, base=True):
def __init__(self, client: AsyncClient):
super().__init__()
self.store = dict()
self.client = client
class ParseTargetException(Exception):
@ -124,11 +124,13 @@ class Platform(metaclass=PlatformABCMeta, base=True):
def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]:
"Return Tag list of given RawPost"
def get_stored_data(self, target: Target) -> Any:
return self.store.get(target)
@classmethod
def get_stored_data(cls, target: Target) -> Any:
return cls.store.get(target)
def set_stored_data(self, target: Target, data: Any):
self.store[target] = data
@classmethod
def set_stored_data(cls, target: Target, data: Any):
cls.store[target] = data
def tag_separator(self, stored_tags: list[Tag]) -> tuple[list[Tag], list[Tag]]:
"""返回分离好的正反tag元组"""

View File

@ -2,6 +2,7 @@ from time import time
from typing import Any, Optional
import pytest
from httpx import AsyncClient
from nonebug.app import App
now = time()
@ -53,8 +54,9 @@ def mock_platform_without_cats_tags(app: App):
categories = {}
has_target = True
sub_index = 0
def __init__(self, client):
self.sub_index = 0
super().__init__(client)
@classmethod
@ -75,14 +77,15 @@ def mock_platform_without_cats_tags(app: App):
target_name="Mock",
)
async def get_sub_list(self, _: "Target"):
if self.sub_index == 0:
self.sub_index += 1
@classmethod
async def get_sub_list(cls, _: "Target"):
if cls.sub_index == 0:
cls.sub_index += 1
return raw_post_list_1
else:
return raw_post_list_2
return MockPlatform(None)
return MockPlatform
@pytest.fixture
@ -112,8 +115,9 @@ def mock_platform(app: App):
Category(2): "视频",
}
sub_index = 0
def __init__(self, client):
self.sub_index = 0
super().__init__(client)
@staticmethod
@ -140,14 +144,15 @@ def mock_platform(app: App):
target_name="Mock",
)
async def get_sub_list(self, _: "Target"):
if self.sub_index == 0:
self.sub_index += 1
@classmethod
async def get_sub_list(cls, _: "Target"):
if cls.sub_index == 0:
cls.sub_index += 1
return raw_post_list_1
else:
return raw_post_list_2
return MockPlatform(None)
return MockPlatform
@pytest.fixture
@ -180,8 +185,9 @@ def mock_platform_no_target(app: App, mock_scheduler_conf):
has_target = False
categories = {Category(1): "转发", Category(2): "视频", Category(3): "不支持"}
sub_index = 0
def __init__(self, client):
self.sub_index = 0
super().__init__(client)
@staticmethod
@ -210,9 +216,10 @@ def mock_platform_no_target(app: App, mock_scheduler_conf):
target_name="Mock",
)
async def get_sub_list(self, _: "Target"):
if self.sub_index == 0:
self.sub_index += 1
@classmethod
async def get_sub_list(cls, _: "Target"):
if cls.sub_index == 0:
cls.sub_index += 1
return raw_post_list_1
else:
return raw_post_list_2
@ -241,8 +248,9 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf):
Category(5): "leixing5",
}
sub_index = 0
def __init__(self, client):
self.sub_index = 0
super().__init__(client)
@classmethod
@ -269,7 +277,8 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf):
target_name="Mock",
)
async def get_sub_list(self, _: "Target"):
@classmethod
async def get_sub_list(cls, _: "Target"):
list_1 = [
{"id": 5, "text": "p5", "date": now, "tags": ["tag1"], "category": 4}
]
@ -278,8 +287,8 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf):
{"id": 6, "text": "p6", "date": now, "tags": ["tag1"], "category": 4},
{"id": 7, "text": "p7", "date": now, "tags": ["tag2"], "category": 5},
]
if self.sub_index == 0:
self.sub_index += 1
if cls.sub_index == 0:
cls.sub_index += 1
return list_1
else:
return list_2
@ -308,16 +317,18 @@ def mock_status_change(app: App):
Category(2): "视频",
}
sub_index = 0
def __init__(self, client):
self.sub_index = 0
super().__init__(client)
async def get_status(self, _: "Target"):
if self.sub_index == 0:
self.sub_index += 1
@classmethod
async def get_status(cls, _: "Target"):
if cls.sub_index == 0:
cls.sub_index += 1
return {"s": False}
elif self.sub_index == 1:
self.sub_index += 1
elif cls.sub_index == 1:
cls.sub_index += 1
return {"s": True}
else:
return {"s": False}
@ -335,18 +346,18 @@ def mock_status_change(app: App):
def get_category(self, raw_post):
return raw_post["cat"]
return MockPlatform(None)
return MockPlatform
@pytest.mark.asyncio
async def test_new_message_target_without_cats_tags(
mock_platform_without_cats_tags, user_info_factory
):
res1 = await mock_platform_without_cats_tags.fetch_new_post(
res1 = await mock_platform_without_cats_tags(AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
)
assert len(res1) == 0
res2 = await mock_platform_without_cats_tags.fetch_new_post(
res2 = await mock_platform_without_cats_tags(AsyncClient()).fetch_new_post(
"dummy",
[
user_info_factory([], []),
@ -361,9 +372,11 @@ async def test_new_message_target_without_cats_tags(
@pytest.mark.asyncio
async def test_new_message_target(mock_platform, user_info_factory):
res1 = await mock_platform.fetch_new_post("dummy", [user_info_factory([1, 2], [])])
res1 = await mock_platform(AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
)
assert len(res1) == 0
res2 = await mock_platform.fetch_new_post(
res2 = await mock_platform(AsyncClient()).fetch_new_post(
"dummy",
[
user_info_factory([1, 2], []),
@ -388,12 +401,11 @@ async def test_new_message_target(mock_platform, user_info_factory):
@pytest.mark.asyncio
async def test_new_message_no_target(mock_platform_no_target, user_info_factory):
mock_platform_no_target = mock_platform_no_target(None)
res1 = await mock_platform_no_target.fetch_new_post(
res1 = await mock_platform_no_target(AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
)
assert len(res1) == 0
res2 = await mock_platform_no_target.fetch_new_post(
res2 = await mock_platform_no_target(AsyncClient()).fetch_new_post(
"dummy",
[
user_info_factory([1, 2], []),
@ -414,7 +426,7 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory)
assert "p2" in id_set_1 and "p3" in id_set_1
assert "p2" in id_set_2
assert "p2" in id_set_3
res3 = await mock_platform_no_target.fetch_new_post(
res3 = await mock_platform_no_target(AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
)
assert len(res3) == 0
@ -422,18 +434,18 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory)
@pytest.mark.asyncio
async def test_status_change(mock_status_change, user_info_factory):
res1 = await mock_status_change.fetch_new_post(
res1 = await mock_status_change(AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
)
assert len(res1) == 0
res2 = await mock_status_change.fetch_new_post(
res2 = await mock_status_change(AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
)
assert len(res2) == 1
posts = res2[0][1]
assert len(posts) == 1
assert posts[0].text == "on"
res3 = await mock_status_change.fetch_new_post(
res3 = await mock_status_change(AsyncClient()).fetch_new_post(
"dummy",
[
user_info_factory([1, 2], []),
@ -444,7 +456,7 @@ async def test_status_change(mock_status_change, user_info_factory):
assert len(res3[0][1]) == 1
assert res3[0][1][0].text == "off"
assert len(res3[1][1]) == 0
res4 = await mock_status_change.fetch_new_post(
res4 = await mock_status_change(AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
)
assert len(res4) == 0