nonebot-bison/tests/platforms/test_platform.py
2022-02-09 21:05:14 +08:00

447 lines
13 KiB
Python

import sys
import typing
from typing import Any, Optional
import pytest
if typing.TYPE_CHECKING:
import sys
sys.path.append("./src/plugins")
import nonebot_bison
from nonebot_bison.types import *
from nonebot_bison.post import Post
from time import time
# Timestamp captured once at import; every mock raw post below is dated `now`.
now = time()
# Three hours before `now`.
passed = now - 60 * 60 * 3

# Raw posts handed out by the mock platforms' first `get_sub_list` call.
raw_post_list_1 = [
    {
        "id": 1,
        "text": "p1",
        "date": now,
        "tags": ["tag1"],
        "category": 1,
    },
]

# Raw posts handed out by later calls: the first batch plus three more posts.
raw_post_list_2 = [
    *raw_post_list_1,
    {"id": 2, "text": "p2", "date": now, "tags": ["tag1"], "category": 1},
    {"id": 3, "text": "p3", "date": now, "tags": ["tag2"], "category": 2},
    {"id": 4, "text": "p4", "date": now, "tags": ["tag2"], "category": 3},
]
@pytest.fixture
def dummy_user(plugin_module: "nonebot_bison"):
    """Provide the ``User("123", "group")`` instance used as the subscriber."""
    return plugin_module.types.User("123", "group")
@pytest.fixture
def user_info_factory(plugin_module: "nonebot_bison", dummy_user):
    """Provide a factory that wraps ``dummy_user`` in a ``UserSubInfo``.

    The returned callable takes ``(category_getter, tag_getter)`` and passes
    them, after ``dummy_user``, to ``plugin_module.types.UserSubInfo``.
    """

    def build_user_info(category_getter, tag_getter):
        sub_info_cls = plugin_module.types.UserSubInfo
        return sub_info_cls(dummy_user, category_getter, tag_getter)

    return build_user_info
@pytest.fixture
def mock_platform_without_cats_tags(plugin_module: "nonebot_bison"):
    """Provide a ``NewMessage`` mock with tags disabled and no categories.

    Its ``get_sub_list`` returns ``raw_post_list_1`` on the first call and
    ``raw_post_list_2`` on every later call.
    """
    new_message_base = plugin_module.platform.platform.NewMessage

    class MockPlatform(new_message_base):
        platform_name = "mock_platform"
        name = "Mock Platform"
        enabled = True
        is_common = True
        schedule_interval = 10
        enable_tag = False
        categories = {}
        has_target = True

        def __init__(self):
            # Stays 0 until the first get_sub_list call has been served.
            self.sub_index = 0
            super().__init__()

        @staticmethod
        async def get_target_name(_: "Target"):
            return "MockPlatform"

        def get_id(self, post: "RawPost") -> Any:
            return post["id"]

        def get_date(self, raw_post: "RawPost") -> float:
            return raw_post["date"]

        async def parse(self, raw_post: "RawPost") -> "Post":
            text = raw_post["text"]
            url = "http://t.tt/" + str(self.get_id(raw_post))
            return plugin_module.post.Post(
                "mock_platform", text, url, target_name="Mock"
            )

        async def get_sub_list(self, _: "Target"):
            if self.sub_index != 0:
                return raw_post_list_2
            # First call: hand out the initial batch and remember we did.
            self.sub_index += 1
            return raw_post_list_1

    return MockPlatform()
@pytest.fixture
def mock_platform(plugin_module: "nonebot_bison"):
    """Provide a ``NewMessage`` mock with tags enabled and categories 1 and 2.

    Tags and category are read straight from each raw post. ``get_sub_list``
    returns ``raw_post_list_1`` on the first call and ``raw_post_list_2`` on
    every later call.
    """
    new_message_base = plugin_module.platform.platform.NewMessage

    class MockPlatform(new_message_base):
        platform_name = "mock_platform"
        name = "Mock Platform"
        enabled = True
        is_common = True
        schedule_interval = 10
        enable_tag = True
        has_target = True
        categories = {1: "转发", 2: "视频"}

        def __init__(self):
            # Stays 0 until the first get_sub_list call has been served.
            self.sub_index = 0
            super().__init__()

        @staticmethod
        async def get_target_name(_: "Target"):
            return "MockPlatform"

        def get_id(self, post: "RawPost") -> Any:
            return post["id"]

        def get_date(self, raw_post: "RawPost") -> float:
            return raw_post["date"]

        def get_tags(self, raw_post: "RawPost") -> list["Tag"]:
            return raw_post["tags"]

        def get_category(self, raw_post: "RawPost") -> "Category":
            return raw_post["category"]

        async def parse(self, raw_post: "RawPost") -> "Post":
            text = raw_post["text"]
            url = "http://t.tt/" + str(self.get_id(raw_post))
            return plugin_module.post.Post(
                "mock_platform", text, url, target_name="Mock"
            )

        async def get_sub_list(self, _: "Target"):
            if self.sub_index != 0:
                return raw_post_list_2
            # First call: hand out the initial batch and remember we did.
            self.sub_index += 1
            return raw_post_list_1

    return MockPlatform()
@pytest.fixture
def mock_platform_no_target(plugin_module: "nonebot_bison"):
    """Provide a target-less ``NewMessage`` mock scheduled every 30 seconds.

    ``get_category`` raises ``CategoryNotSupport`` for raw posts whose
    category is 3. ``get_sub_list`` returns ``raw_post_list_1`` on the first
    call and ``raw_post_list_2`` on every later call.
    """
    platform_ns = plugin_module.platform.platform

    class MockPlatform(platform_ns.NewMessage):
        platform_name = "mock_platform"
        name = "Mock Platform"
        enabled = True
        is_common = True
        schedule_type = "interval"
        schedule_kw = {"seconds": 30}
        enable_tag = True
        has_target = False
        categories = {
            1: "转发",
            2: "视频",
            3: "不支持",
        }

        def __init__(self):
            # Stays 0 until the first get_sub_list call has been served.
            self.sub_index = 0
            super().__init__()

        @staticmethod
        async def get_target_name(_: "Target"):
            return "MockPlatform"

        def get_id(self, post: "RawPost") -> Any:
            return post["id"]

        def get_date(self, raw_post: "RawPost") -> float:
            return raw_post["date"]

        def get_tags(self, raw_post: "RawPost") -> list["Tag"]:
            return raw_post["tags"]

        def get_category(self, raw_post: "RawPost") -> "Category":
            category = raw_post["category"]
            if category == 3:
                raise plugin_module.platform.platform.CategoryNotSupport()
            return category

        async def parse(self, raw_post: "RawPost") -> "Post":
            text = raw_post["text"]
            url = "http://t.tt/" + str(self.get_id(raw_post))
            return plugin_module.post.Post(
                "mock_platform", text, url, target_name="Mock"
            )

        async def get_sub_list(self, _: "Target"):
            if self.sub_index != 0:
                return raw_post_list_2
            # First call: hand out the initial batch and remember we did.
            self.sub_index += 1
            return raw_post_list_1

    return MockPlatform()
@pytest.fixture
def mock_platform_no_target_2(plugin_module: "nonebot_bison"):
    """Provide a second target-less ``NewMessage`` mock with categories 4 and 5.

    ``get_sub_list`` returns only post p5 on the first call and p5, p6, p7 on
    every later call. Parsed posts are labelled ``"mock_platform_2"``.
    """
    new_message_base = plugin_module.platform.platform.NewMessage

    class MockPlatform(new_message_base):
        platform_name = "mock_platform"
        name = "Mock Platform"
        enabled = True
        schedule_type = "interval"
        schedule_kw = {"seconds": 30}
        is_common = True
        enable_tag = True
        has_target = False
        categories = {4: "leixing4", 5: "leixing5"}

        def __init__(self):
            # Stays 0 until the first get_sub_list call has been served.
            self.sub_index = 0
            super().__init__()

        @staticmethod
        async def get_target_name(_: "Target"):
            return "MockPlatform"

        def get_id(self, post: "RawPost") -> Any:
            return post["id"]

        def get_date(self, raw_post: "RawPost") -> float:
            return raw_post["date"]

        def get_tags(self, raw_post: "RawPost") -> list["Tag"]:
            return raw_post["tags"]

        def get_category(self, raw_post: "RawPost") -> "Category":
            return raw_post["category"]

        async def parse(self, raw_post: "RawPost") -> "Post":
            text = raw_post["text"]
            url = "http://t.tt/" + str(self.get_id(raw_post))
            return plugin_module.post.Post(
                "mock_platform_2", text, url, target_name="Mock"
            )

        async def get_sub_list(self, _: "Target"):
            initial_batch = [
                {"id": 5, "text": "p5", "date": now, "tags": ["tag1"], "category": 4},
            ]
            if self.sub_index == 0:
                self.sub_index += 1
                return initial_batch
            # Later calls: the initial post followed by two new ones.
            return initial_batch + [
                {"id": 6, "text": "p6", "date": now, "tags": ["tag1"], "category": 4},
                {"id": 7, "text": "p7", "date": now, "tags": ["tag2"], "category": 5},
            ]

    return MockPlatform()
@pytest.fixture
def mock_status_change(plugin_module: "nonebot_bison"):
    """Provide a target-less ``StatusChange`` mock with a scripted status.

    ``get_status`` reports ``{"s": False}``, then ``{"s": True}``, then
    ``{"s": False}`` for every later call. ``compare_status`` turns an
    off-to-on transition into an "on" raw post (category 1) and an on-to-off
    transition into an "off" raw post (category 2).
    """

    class MockPlatform(plugin_module.platform.platform.StatusChange):
        platform_name = "mock_platform"
        name = "Mock Platform"
        enabled = True
        is_common = True
        enable_tag = False
        schedule_type = "interval"
        schedule_kw = {"seconds": 10}
        has_target = False
        categories = {
            1: "转发",
            2: "视频",
        }

        def __init__(self):
            # Number of scripted status reports already served (capped at 2).
            self.sub_index = 0
            super().__init__()

        async def get_status(self, _: "Target"):
            if self.sub_index == 0:
                self.sub_index += 1
                return {"s": False}
            elif self.sub_index == 1:
                self.sub_index += 1
                return {"s": True}
            else:
                return {"s": False}

        def compare_status(self, target, old_status, new_status) -> list["RawPost"]:
            # Statuses are plain booleans, so test truthiness directly instead
            # of comparing against True/False with ``==``.
            was_on = old_status["s"]
            is_on = new_status["s"]
            if not was_on and is_on:
                return [{"text": "on", "cat": 1}]
            if was_on and not is_on:
                return [{"text": "off", "cat": 2}]
            # No transition: nothing to report.
            return []

        async def parse(self, raw_post) -> "Post":
            return plugin_module.post.Post("mock_status", raw_post["text"], "")

        def get_category(self, raw_post):
            return raw_post["cat"]

    return MockPlatform()
@pytest.mark.asyncio
async def test_new_message_target_without_cats_tags(
    mock_platform_without_cats_tags, user_info_factory
):
    """Second fetch on the tag/category-less mock delivers p2, p3 and p4."""
    platform = mock_platform_without_cats_tags

    # The first fetch is expected to return nothing.
    first_fetch = await platform.fetch_new_post(
        "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])]
    )
    assert len(first_fetch) == 0

    subscribers = [user_info_factory(lambda _: [], lambda _: [])]
    second_fetch = await platform.fetch_new_post("dummy", subscribers)
    assert len(second_fetch) == 1

    delivered = second_fetch[0][1]
    assert len(delivered) == 3
    delivered_texts = {post.text for post in delivered}
    assert {"p2", "p3", "p4"} <= delivered_texts
@pytest.mark.asyncio
async def test_new_message_target(mock_platform, user_info_factory):
    """Second fetch filters new posts per subscriber by category and tag."""
    # The first fetch is expected to return nothing.
    first_fetch = await mock_platform.fetch_new_post(
        "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])]
    )
    assert len(first_fetch) == 0

    subscribers = [
        # Categories 1 and 2, no tag filter.
        user_info_factory(lambda _: [1, 2], lambda _: []),
        # Category 1 only.
        user_info_factory(lambda _: [1], lambda _: []),
        # Categories 1 and 2, restricted to "tag1".
        user_info_factory(lambda _: [1, 2], lambda _: ["tag1"]),
    ]
    second_fetch = await mock_platform.fetch_new_post("dummy", subscribers)
    assert len(second_fetch) == 3

    posts_both_cats = second_fetch[0][1]
    posts_cat_1 = second_fetch[1][1]
    posts_tag_1 = second_fetch[2][1]
    assert len(posts_both_cats) == 2
    assert len(posts_cat_1) == 1
    assert len(posts_tag_1) == 1

    texts_both_cats = {post.text for post in posts_both_cats}
    texts_cat_1 = {post.text for post in posts_cat_1}
    texts_tag_1 = {post.text for post in posts_tag_1}
    assert {"p2", "p3"} <= texts_both_cats
    assert "p2" in texts_cat_1
    assert "p2" in texts_tag_1
@pytest.mark.asyncio
async def test_new_message_no_target(mock_platform_no_target, user_info_factory):
    """Target-less mock: filtered delivery on fetch two, nothing on fetch three."""
    platform = mock_platform_no_target

    # The first fetch is expected to return nothing.
    first_fetch = await platform.fetch_new_post(
        "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])]
    )
    assert len(first_fetch) == 0

    subscribers = [
        # Categories 1 and 2, no tag filter.
        user_info_factory(lambda _: [1, 2], lambda _: []),
        # Category 1 only.
        user_info_factory(lambda _: [1], lambda _: []),
        # Categories 1 and 2, restricted to "tag1".
        user_info_factory(lambda _: [1, 2], lambda _: ["tag1"]),
    ]
    second_fetch = await platform.fetch_new_post("dummy", subscribers)
    assert len(second_fetch) == 3

    posts_both_cats = second_fetch[0][1]
    posts_cat_1 = second_fetch[1][1]
    posts_tag_1 = second_fetch[2][1]
    assert len(posts_both_cats) == 2
    assert len(posts_cat_1) == 1
    assert len(posts_tag_1) == 1

    texts_both_cats = {post.text for post in posts_both_cats}
    texts_cat_1 = {post.text for post in posts_cat_1}
    texts_tag_1 = {post.text for post in posts_tag_1}
    assert {"p2", "p3"} <= texts_both_cats
    assert "p2" in texts_cat_1
    assert "p2" in texts_tag_1

    # The mock keeps returning the same list, so a third fetch yields nothing.
    third_fetch = await platform.fetch_new_post(
        "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])]
    )
    assert len(third_fetch) == 0
@pytest.mark.asyncio
async def test_status_change(mock_status_change, user_info_factory):
    """Status mock yields "on" on fetch two, "off" on fetch three, then nothing."""
    platform = mock_status_change

    # Fetch 1: the mock reports its first status; nothing is expected back.
    first_fetch = await platform.fetch_new_post(
        "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])]
    )
    assert len(first_fetch) == 0

    # Fetch 2: status flips to on.
    second_fetch = await platform.fetch_new_post(
        "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])]
    )
    assert len(second_fetch) == 1
    on_posts = second_fetch[0][1]
    assert len(on_posts) == 1
    assert on_posts[0].text == "on"

    # Fetch 3: status flips to off; the "off" raw post carries category 2, so
    # only the subscriber that includes category 2 gets it.
    subscribers = [
        user_info_factory(lambda _: [1, 2], lambda _: []),
        user_info_factory(lambda _: [1], lambda _: []),
    ]
    third_fetch = await platform.fetch_new_post("dummy", subscribers)
    assert len(third_fetch) == 2
    posts_both_cats = third_fetch[0][1]
    posts_cat_1 = third_fetch[1][1]
    assert len(posts_both_cats) == 1
    assert posts_both_cats[0].text == "off"
    assert len(posts_cat_1) == 0

    # Fetch 4: status stays off, so nothing is expected.
    fourth_fetch = await platform.fetch_new_post(
        "dummy", [user_info_factory(lambda _: [1, 2], lambda _: [])]
    )
    assert len(fourth_fetch) == 0
@pytest.mark.asyncio
async def test_group(
    plugin_module: "nonebot_bison",
    mock_platform_no_target,
    mock_platform_no_target_2,
    user_info_factory,
):
    """A ``NoTargetGroup`` of both target-less mocks merges their new posts."""
    members = [mock_platform_no_target, mock_platform_no_target_2]
    group_platform = plugin_module.platform.platform.NoTargetGroup(members)

    # The first fetch is expected to return nothing.
    first_fetch = await group_platform.fetch_new_post(
        "dummy", [user_info_factory(lambda _: [1, 4], lambda _: [])]
    )
    assert len(first_fetch) == 0

    # Second fetch: one category-1 post (p2) and one category-4 post (p6).
    second_fetch = await group_platform.fetch_new_post(
        "dummy", [user_info_factory(lambda _: [1, 4], lambda _: [])]
    )
    assert len(second_fetch) == 1
    delivered = second_fetch[0][1]
    assert len(delivered) == 2
    delivered_texts = {post.text for post in delivered}
    assert {"p2", "p6"} <= delivered_texts

    # Third fetch: the mocks return the same lists again, so nothing is new.
    third_fetch = await group_platform.fetch_new_post(
        "dummy", [user_info_factory(lambda _: [1, 4], lambda _: [])]
    )
    assert len(third_fetch) == 0