mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2026-05-09 18:27:56 +08:00
🚧 remove User type
This commit is contained in:
@@ -1,53 +1,57 @@
|
||||
""" 提供获取 Bot 的方法 """
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from typing import Any, Optional
|
||||
|
||||
import nonebot
|
||||
from nonebot import get_driver, on_notice
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.adapters.onebot.v11 import Bot as Ob11Bot
|
||||
from nonebot.adapters.onebot.v11 import (
|
||||
Bot,
|
||||
FriendAddNoticeEvent,
|
||||
GroupDecreaseNoticeEvent,
|
||||
GroupIncreaseNoticeEvent,
|
||||
)
|
||||
|
||||
from ..types import User
|
||||
from nonebot_plugin_saa import PlatformTarget, TargetQQGroup, TargetQQPrivate
|
||||
|
||||
GROUP: dict[int, list[Bot]] = {}
|
||||
USER: dict[int, list[Bot]] = {}
|
||||
BOT_CACHE: dict[PlatformTarget, list[Bot]] = defaultdict(list)
|
||||
|
||||
|
||||
def get_bots() -> list[Bot]:
|
||||
"""获取所有 OneBot 11 Bot"""
|
||||
# TODO: support ob12
|
||||
bots = []
|
||||
for bot in nonebot.get_bots().values():
|
||||
if isinstance(bot, Bot):
|
||||
if isinstance(bot, Ob11Bot):
|
||||
bots.append(bot)
|
||||
return bots
|
||||
|
||||
|
||||
async def _refresh_ob11(bot: Ob11Bot):
|
||||
# 获取群列表
|
||||
groups = await bot.get_group_list()
|
||||
for group in groups:
|
||||
group_id = group["group_id"]
|
||||
target = TargetQQGroup(group_id=group_id)
|
||||
BOT_CACHE[target].append(bot)
|
||||
|
||||
# 获取好友列表
|
||||
users = await bot.get_friend_list()
|
||||
for user in users:
|
||||
user_id = user["user_id"]
|
||||
target = TargetQQPrivate(user_id=user_id)
|
||||
BOT_CACHE[target].append(bot)
|
||||
|
||||
|
||||
async def refresh_bots():
|
||||
"""刷新缓存的 Bot 数据"""
|
||||
GROUP.clear()
|
||||
USER.clear()
|
||||
BOT_CACHE.clear()
|
||||
for bot in get_bots():
|
||||
# 获取群列表
|
||||
groups = await bot.get_group_list()
|
||||
for group in groups:
|
||||
group_id = group["group_id"]
|
||||
if group_id not in GROUP:
|
||||
GROUP[group_id] = [bot]
|
||||
else:
|
||||
GROUP[group_id].append(bot)
|
||||
|
||||
# 获取好友列表
|
||||
users = await bot.get_friend_list()
|
||||
for user in users:
|
||||
user_id = user["user_id"]
|
||||
if user_id not in USER:
|
||||
USER[user_id] = [bot]
|
||||
else:
|
||||
USER[user_id].append(bot)
|
||||
match bot:
|
||||
case Ob11Bot():
|
||||
await _refresh_ob11(bot)
|
||||
|
||||
|
||||
driver = get_driver()
|
||||
@@ -75,15 +79,9 @@ async def _(bot: Bot, event: GroupDecreaseNoticeEvent | GroupIncreaseNoticeEvent
|
||||
await refresh_bots()
|
||||
|
||||
|
||||
def get_bot(user: User) -> Optional[Bot]:
|
||||
def get_bot(user: PlatformTarget) -> Optional[Bot]:
|
||||
"""获取 Bot"""
|
||||
bots = []
|
||||
if user.user_type == "group":
|
||||
bots = GROUP.get(user.user, [])
|
||||
|
||||
if user.user_type == "private":
|
||||
bots = USER.get(user.user, [])
|
||||
|
||||
bots = BOT_CACHE.get(user)
|
||||
if not bots:
|
||||
return
|
||||
|
||||
@@ -92,6 +90,7 @@ def get_bot(user: User) -> Optional[Bot]:
|
||||
|
||||
async def get_groups() -> list[dict[str, Any]]:
|
||||
"""获取所有群号"""
|
||||
# TODO
|
||||
all_groups: dict[int, dict[str, Any]] = {}
|
||||
for bot in get_bots():
|
||||
groups = await bot.get_group_list()
|
||||
|
||||
Reference in New Issue
Block a user