from dataclasses import dataclass, field from functools import reduce from io import BytesIO from typing import Optional, Union from nonebot.adapters.onebot.v11.message import Message, MessageSegment from nonebot.log import logger from PIL import Image from ..utils import http_client, parse_text from .abstract_post import AbstractPost, BasePost, OptionalMixin @dataclass class _Post(BasePost): target_type: str text: str url: Optional[str] = None target_name: Optional[str] = None pics: list[Union[str, bytes]] = field(default_factory=list) _message: Optional[list[MessageSegment]] = None _pic_message: Optional[list[MessageSegment]] = None async def _pic_url_to_image(self, data: Union[str, bytes]) -> Image.Image: pic_buffer = BytesIO() if isinstance(data, str): async with http_client() as client: res = await client.get(data) pic_buffer.write(res.content) else: pic_buffer.write(data) return Image.open(pic_buffer) def _check_image_square(self, size: tuple[int, int]) -> bool: return abs(size[0] - size[1]) / size[0] < 0.05 async def _pic_merge(self) -> None: if len(self.pics) < 3: return first_image = await self._pic_url_to_image(self.pics[0]) if not self._check_image_square(first_image.size): return images: list[Image.Image] = [first_image] # first row for i in range(1, 3): cur_img = await self._pic_url_to_image(self.pics[i]) if not self._check_image_square(cur_img.size): return if cur_img.size[1] != images[0].size[1]: # height not equal return images.append(cur_img) _tmp = 0 x_coord = [0] for i in range(3): _tmp += images[i].size[0] x_coord.append(_tmp) y_coord = [0, first_image.size[1]] async def process_row(row: int) -> bool: if len(self.pics) < (row + 1) * 3: return False row_first_img = await self._pic_url_to_image(self.pics[row * 3]) if not self._check_image_square(row_first_img.size): return False if row_first_img.size[0] != images[0].size[0]: return False image_row: list[Image.Image] = [row_first_img] for i in range(row * 3 + 1, row * 3 + 3): cur_img = await self._pic_url_to_image(self.pics[i]) if not self._check_image_square(cur_img.size): return False if cur_img.size[1] != row_first_img.size[1]: return False if cur_img.size[0] != images[i % 3].size[0]: return False image_row.append(cur_img) images.extend(image_row) y_coord.append(y_coord[-1] + row_first_img.size[1]) return True if await process_row(1): matrix = (3, 2) else: matrix = (3, 1) if await process_row(2): matrix = (3, 3) logger.info("trigger merge image") target = Image.new("RGB", (x_coord[-1], y_coord[-1])) for y in range(matrix[1]): for x in range(matrix[0]): target.paste( images[y * matrix[0] + x], (x_coord[x], y_coord[y], x_coord[x + 1], y_coord[y + 1]), ) target_io = BytesIO() target.save(target_io, "JPEG") self.pics = self.pics[matrix[0] * matrix[1] :] self.pics.insert(0, target_io.getvalue()) async def generate_text_messages(self) -> list[MessageSegment]: if self._message is None: await self._pic_merge() msg_segments: list[MessageSegment] = [] text = "" if self.text: text += "{}".format( self.text if len(self.text) < 500 else self.text[:500] + "..." ) if text: text += "\n" text += "来源: {}".format(self.target_type) if self.target_name: text += " {}".format(self.target_name) if self.url: text += " \n详情: {}".format(self.url) msg_segments.append(MessageSegment.text(text)) for pic in self.pics: msg_segments.append(MessageSegment.image(pic)) self._message = msg_segments return self._message async def generate_pic_messages(self) -> list[MessageSegment]: if self._pic_message is None: await self._pic_merge() msg_segments: list[MessageSegment] = [] text = "" if self.text: text += "{}".format(self.text) text += "\n" text += "来源: {}".format(self.target_type) if self.target_name: text += " {}".format(self.target_name) msg_segments.append(await parse_text(text)) if not self.target_type == "rss" and self.url: msg_segments.append(MessageSegment.text(self.url)) for pic in self.pics: msg_segments.append(MessageSegment.image(pic)) self._pic_message = msg_segments return self._pic_message def __str__(self): return "type: {}\nfrom: {}\ntext: {}\nurl: {}\npic: {}".format( self.target_type, self.target_name, self.text if len(self.text) < 500 else self.text[:500] + "...", self.url, ", ".join( map( lambda x: "b64img" if isinstance(x, bytes) or x.startswith("base64") else x, self.pics, ) ), ) @dataclass class Post(AbstractPost, _Post): pass