2021-11-17 16:57:04 +08:00

147 lines
5.5 KiB
Python

import base64
from dataclasses import dataclass, field
from functools import reduce
from io import BytesIO
from typing import Optional, Union
from PIL import Image
import httpx
from nonebot import logger
from nonebot.adapters.cqhttp.message import Message, MessageSegment
from .plugin_config import plugin_config
from .utils import parse_text
@dataclass
class Post:
target_type: str
text: str
url: Optional[str] = None
target_name: Optional[str] = None
compress: bool = False
override_use_pic: Optional[bool] = None
pics: Union[list[Union[str,bytes]], list[str], list[bytes]] = field(default_factory=list)
extra_msg: list[Message] = field(default_factory=list)
_message: Optional[list] = None
def _use_pic(self):
if not self.override_use_pic is None:
return self.override_use_pic
return plugin_config.bison_use_pic
async def _pic_url_to_image(self, data: Union[str, bytes]) -> Image.Image:
pic_buffer = BytesIO()
if isinstance(data, str):
async with httpx.AsyncClient() as client:
res = await client.get(data)
pic_buffer.write(res.content)
else:
pic_buffer.write(data)
return Image.open(pic_buffer)
def _check_image_square(self, size: tuple[int, int]) -> bool:
return abs(size[0] - size[1]) / size[0] < 0.01
async def _pic_merge(self) -> None:
if len(self.pics) < 3:
return
first_image = await self._pic_url_to_image(self.pics[0])
if not self._check_image_square(first_image.size):
return
images: list[Image.Image] = [first_image]
# first row
for i in range(1, 3):
cur_img = await self._pic_url_to_image(self.pics[i])
if not self._check_image_square(cur_img.size):
return
if cur_img.size[1] != images[0].size[1]: # height not equal
return
images.append(cur_img)
_tmp = 0
x_coord = [0]
for i in range(3):
_tmp += images[i].size[0]
x_coord.append(_tmp)
y_coord = [0, first_image.size[1]]
async def process_row(row: int) -> bool:
if len(self.pics) < (row + 1) * 3:
return False
row_first_img = await self._pic_url_to_image(self.pics[row * 3])
if not self._check_image_square(row_first_img.size):
return False
if row_first_img.size[0] != images[0].size[0]:
return False
image_row: list[Image.Image] = [row_first_img]
for i in range(row * 3 + 1, row * 3 + 3):
cur_img = await self._pic_url_to_image(self.pics[i])
if not self._check_image_square(cur_img.size):
return False
if cur_img.size[1] != row_first_img.size[1]:
return False
if cur_img.size[0] != images[i % 3].size[0]:
return False
image_row.append(cur_img)
images.extend(image_row)
y_coord.append(y_coord[-1] + row_first_img.size[1])
return True
if await process_row(1):
matrix = (3,2)
else:
matrix = (3,1)
if await process_row(2):
matrix = (3,3)
logger.info('trigger merge image')
target = Image.new('RGB', (x_coord[-1], y_coord[-1]))
for y in range(matrix[1]):
for x in range(matrix[0]):
target.paste(images[y * matrix[0] + x], (
x_coord[x], y_coord[y], x_coord[x+1], y_coord[y+1]
))
target_io = BytesIO()
target.save(target_io, 'JPEG')
self.pics = self.pics[matrix[0] * matrix[1]: ]
self.pics.insert(0, target_io.getvalue())
async def generate_messages(self):
if self._message is None:
await self._pic_merge()
msgs = []
text = ''
if self.text:
if self._use_pic():
text += '{}'.format(self.text)
else:
text += '{}'.format(self.text if len(self.text) < 500 else self.text[:500] + '...')
text += '\n来源: {}'.format(self.target_type)
if self.target_name:
text += ' {}'.format(self.target_name)
if self._use_pic():
msgs.append(await parse_text(text))
if not self.target_type == 'rss' and self.url:
msgs.append(MessageSegment.text(self.url))
else:
if self.url:
text += ' \n详情: {}'.format(self.url)
msgs.append(MessageSegment.text(text))
for pic in self.pics:
# if isinstance(pic, bytes):
# pic = 'base64://' + base64.b64encode(pic).decode()
# msgs.append(Message("[CQ:image,file={url}]".format(url=pic)))
msgs.append(MessageSegment.image(pic))
if self.compress:
msgs = [reduce(lambda x, y: x.append(y), msgs, Message())]
msgs.extend(self.extra_msg)
self._message = msgs
return self._message
def __str__(self):
return 'type: {}\nfrom: {}\ntext: {}\nurl: {}\npic: {}'.format(
self.target_type,
self.target_name,
self.text if len(self.text) < 500 else self.text[:500] + '...',
self.url,
', '.join(map(lambda x: 'b64img' if isinstance(x, bytes) or x.startswith('base64') else x, self.pics))
)