Merge pull request #129 from felinae98/fix-new-message

fix platform store init error
This commit is contained in:
felinae98 2022-10-17 20:20:23 +08:00 committed by GitHub
commit fdaee5f227
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 43 deletions

View File

@ -39,9 +39,11 @@ class RegistryMeta(type):
class PlatformMeta(RegistryMeta): class PlatformMeta(RegistryMeta):
categories: dict[Category, str] categories: dict[Category, str]
store: dict[Target, Any]
def __init__(cls, name, bases, namespace, **kwargs): def __init__(cls, name, bases, namespace, **kwargs):
cls.reverse_category = {} cls.reverse_category = {}
cls.store = {}
if hasattr(cls, "categories") and cls.categories: if hasattr(cls, "categories") and cls.categories:
for key, val in cls.categories.items(): for key, val in cls.categories.items():
cls.reverse_category[val] = key cls.reverse_category[val] = key
@ -61,7 +63,6 @@ class Platform(metaclass=PlatformABCMeta, base=True):
has_target: bool has_target: bool
categories: dict[Category, str] categories: dict[Category, str]
enable_tag: bool enable_tag: bool
store: dict[Target, Any]
platform_name: str platform_name: str
parse_target_promot: Optional[str] = None parse_target_promot: Optional[str] = None
registry: list[Type["Platform"]] registry: list[Type["Platform"]]
@ -110,7 +111,6 @@ class Platform(metaclass=PlatformABCMeta, base=True):
def __init__(self, client: AsyncClient): def __init__(self, client: AsyncClient):
super().__init__() super().__init__()
self.store = dict()
self.client = client self.client = client
class ParseTargetException(Exception): class ParseTargetException(Exception):
@ -124,11 +124,13 @@ class Platform(metaclass=PlatformABCMeta, base=True):
def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]: def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]:
"Return Tag list of given RawPost" "Return Tag list of given RawPost"
def get_stored_data(self, target: Target) -> Any: @classmethod
return self.store.get(target) def get_stored_data(cls, target: Target) -> Any:
return cls.store.get(target)
def set_stored_data(self, target: Target, data: Any): @classmethod
self.store[target] = data def set_stored_data(cls, target: Target, data: Any):
cls.store[target] = data
def tag_separator(self, stored_tags: list[Tag]) -> tuple[list[Tag], list[Tag]]: def tag_separator(self, stored_tags: list[Tag]) -> tuple[list[Tag], list[Tag]]:
"""返回分离好的正反tag元组""" """返回分离好的正反tag元组"""

View File

@ -2,6 +2,7 @@ from time import time
from typing import Any, Optional from typing import Any, Optional
import pytest import pytest
from httpx import AsyncClient
from nonebug.app import App from nonebug.app import App
now = time() now = time()
@ -53,8 +54,9 @@ def mock_platform_without_cats_tags(app: App):
categories = {} categories = {}
has_target = True has_target = True
sub_index = 0
def __init__(self, client): def __init__(self, client):
self.sub_index = 0
super().__init__(client) super().__init__(client)
@classmethod @classmethod
@ -75,14 +77,15 @@ def mock_platform_without_cats_tags(app: App):
target_name="Mock", target_name="Mock",
) )
async def get_sub_list(self, _: "Target"): @classmethod
if self.sub_index == 0: async def get_sub_list(cls, _: "Target"):
self.sub_index += 1 if cls.sub_index == 0:
cls.sub_index += 1
return raw_post_list_1 return raw_post_list_1
else: else:
return raw_post_list_2 return raw_post_list_2
return MockPlatform(None) return MockPlatform
@pytest.fixture @pytest.fixture
@ -112,8 +115,9 @@ def mock_platform(app: App):
Category(2): "视频", Category(2): "视频",
} }
sub_index = 0
def __init__(self, client): def __init__(self, client):
self.sub_index = 0
super().__init__(client) super().__init__(client)
@staticmethod @staticmethod
@ -140,14 +144,15 @@ def mock_platform(app: App):
target_name="Mock", target_name="Mock",
) )
async def get_sub_list(self, _: "Target"): @classmethod
if self.sub_index == 0: async def get_sub_list(cls, _: "Target"):
self.sub_index += 1 if cls.sub_index == 0:
cls.sub_index += 1
return raw_post_list_1 return raw_post_list_1
else: else:
return raw_post_list_2 return raw_post_list_2
return MockPlatform(None) return MockPlatform
@pytest.fixture @pytest.fixture
@ -180,8 +185,9 @@ def mock_platform_no_target(app: App, mock_scheduler_conf):
has_target = False has_target = False
categories = {Category(1): "转发", Category(2): "视频", Category(3): "不支持"} categories = {Category(1): "转发", Category(2): "视频", Category(3): "不支持"}
sub_index = 0
def __init__(self, client): def __init__(self, client):
self.sub_index = 0
super().__init__(client) super().__init__(client)
@staticmethod @staticmethod
@ -210,9 +216,10 @@ def mock_platform_no_target(app: App, mock_scheduler_conf):
target_name="Mock", target_name="Mock",
) )
async def get_sub_list(self, _: "Target"): @classmethod
if self.sub_index == 0: async def get_sub_list(cls, _: "Target"):
self.sub_index += 1 if cls.sub_index == 0:
cls.sub_index += 1
return raw_post_list_1 return raw_post_list_1
else: else:
return raw_post_list_2 return raw_post_list_2
@ -241,8 +248,9 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf):
Category(5): "leixing5", Category(5): "leixing5",
} }
sub_index = 0
def __init__(self, client): def __init__(self, client):
self.sub_index = 0
super().__init__(client) super().__init__(client)
@classmethod @classmethod
@ -269,7 +277,8 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf):
target_name="Mock", target_name="Mock",
) )
async def get_sub_list(self, _: "Target"): @classmethod
async def get_sub_list(cls, _: "Target"):
list_1 = [ list_1 = [
{"id": 5, "text": "p5", "date": now, "tags": ["tag1"], "category": 4} {"id": 5, "text": "p5", "date": now, "tags": ["tag1"], "category": 4}
] ]
@ -278,8 +287,8 @@ def mock_platform_no_target_2(app: App, mock_scheduler_conf):
{"id": 6, "text": "p6", "date": now, "tags": ["tag1"], "category": 4}, {"id": 6, "text": "p6", "date": now, "tags": ["tag1"], "category": 4},
{"id": 7, "text": "p7", "date": now, "tags": ["tag2"], "category": 5}, {"id": 7, "text": "p7", "date": now, "tags": ["tag2"], "category": 5},
] ]
if self.sub_index == 0: if cls.sub_index == 0:
self.sub_index += 1 cls.sub_index += 1
return list_1 return list_1
else: else:
return list_2 return list_2
@ -308,16 +317,18 @@ def mock_status_change(app: App):
Category(2): "视频", Category(2): "视频",
} }
sub_index = 0
def __init__(self, client): def __init__(self, client):
self.sub_index = 0
super().__init__(client) super().__init__(client)
async def get_status(self, _: "Target"): @classmethod
if self.sub_index == 0: async def get_status(cls, _: "Target"):
self.sub_index += 1 if cls.sub_index == 0:
cls.sub_index += 1
return {"s": False} return {"s": False}
elif self.sub_index == 1: elif cls.sub_index == 1:
self.sub_index += 1 cls.sub_index += 1
return {"s": True} return {"s": True}
else: else:
return {"s": False} return {"s": False}
@ -335,18 +346,18 @@ def mock_status_change(app: App):
def get_category(self, raw_post): def get_category(self, raw_post):
return raw_post["cat"] return raw_post["cat"]
return MockPlatform(None) return MockPlatform
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_message_target_without_cats_tags( async def test_new_message_target_without_cats_tags(
mock_platform_without_cats_tags, user_info_factory mock_platform_without_cats_tags, user_info_factory
): ):
res1 = await mock_platform_without_cats_tags.fetch_new_post( res1 = await mock_platform_without_cats_tags(AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])] "dummy", [user_info_factory([1, 2], [])]
) )
assert len(res1) == 0 assert len(res1) == 0
res2 = await mock_platform_without_cats_tags.fetch_new_post( res2 = await mock_platform_without_cats_tags(AsyncClient()).fetch_new_post(
"dummy", "dummy",
[ [
user_info_factory([], []), user_info_factory([], []),
@ -361,9 +372,11 @@ async def test_new_message_target_without_cats_tags(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_message_target(mock_platform, user_info_factory): async def test_new_message_target(mock_platform, user_info_factory):
res1 = await mock_platform.fetch_new_post("dummy", [user_info_factory([1, 2], [])]) res1 = await mock_platform(AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])]
)
assert len(res1) == 0 assert len(res1) == 0
res2 = await mock_platform.fetch_new_post( res2 = await mock_platform(AsyncClient()).fetch_new_post(
"dummy", "dummy",
[ [
user_info_factory([1, 2], []), user_info_factory([1, 2], []),
@ -388,12 +401,11 @@ async def test_new_message_target(mock_platform, user_info_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_message_no_target(mock_platform_no_target, user_info_factory): async def test_new_message_no_target(mock_platform_no_target, user_info_factory):
mock_platform_no_target = mock_platform_no_target(None) res1 = await mock_platform_no_target(AsyncClient()).fetch_new_post(
res1 = await mock_platform_no_target.fetch_new_post(
"dummy", [user_info_factory([1, 2], [])] "dummy", [user_info_factory([1, 2], [])]
) )
assert len(res1) == 0 assert len(res1) == 0
res2 = await mock_platform_no_target.fetch_new_post( res2 = await mock_platform_no_target(AsyncClient()).fetch_new_post(
"dummy", "dummy",
[ [
user_info_factory([1, 2], []), user_info_factory([1, 2], []),
@ -414,7 +426,7 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory)
assert "p2" in id_set_1 and "p3" in id_set_1 assert "p2" in id_set_1 and "p3" in id_set_1
assert "p2" in id_set_2 assert "p2" in id_set_2
assert "p2" in id_set_3 assert "p2" in id_set_3
res3 = await mock_platform_no_target.fetch_new_post( res3 = await mock_platform_no_target(AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])] "dummy", [user_info_factory([1, 2], [])]
) )
assert len(res3) == 0 assert len(res3) == 0
@ -422,18 +434,18 @@ async def test_new_message_no_target(mock_platform_no_target, user_info_factory)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_status_change(mock_status_change, user_info_factory): async def test_status_change(mock_status_change, user_info_factory):
res1 = await mock_status_change.fetch_new_post( res1 = await mock_status_change(AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])] "dummy", [user_info_factory([1, 2], [])]
) )
assert len(res1) == 0 assert len(res1) == 0
res2 = await mock_status_change.fetch_new_post( res2 = await mock_status_change(AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])] "dummy", [user_info_factory([1, 2], [])]
) )
assert len(res2) == 1 assert len(res2) == 1
posts = res2[0][1] posts = res2[0][1]
assert len(posts) == 1 assert len(posts) == 1
assert posts[0].text == "on" assert posts[0].text == "on"
res3 = await mock_status_change.fetch_new_post( res3 = await mock_status_change(AsyncClient()).fetch_new_post(
"dummy", "dummy",
[ [
user_info_factory([1, 2], []), user_info_factory([1, 2], []),
@ -444,7 +456,7 @@ async def test_status_change(mock_status_change, user_info_factory):
assert len(res3[0][1]) == 1 assert len(res3[0][1]) == 1
assert res3[0][1][0].text == "off" assert res3[0][1][0].text == "off"
assert len(res3[1][1]) == 0 assert len(res3[1][1]) == 0
res4 = await mock_status_change.fetch_new_post( res4 = await mock_status_change(AsyncClient()).fetch_new_post(
"dummy", [user_info_factory([1, 2], [])] "dummy", [user_info_factory([1, 2], [])]
) )
assert len(res4) == 0 assert len(res4) == 0