diff --git a/pom.xml b/pom.xml index 4374986..531b24c 100644 --- a/pom.xml +++ b/pom.xml @@ -95,16 +95,38 @@ spring-restdocs-mockmvc test + org.springdoc springdoc-openapi-starter-webmvc-ui - 2.3.0 + 2.6.0 + org.xerial sqlite-jdbc test + + com.amazonaws + aws-java-sdk-s3 + 1.12.706 + + + commons-io + commons-io + 2.11.0 + + + org.mockito + mockito-core + 4.0.0 + test + + + org.springframework.boot + spring-boot-starter-validation + diff --git a/src/main/java/top/suyiiyii/sims/common/JwtInterceptor.java b/src/main/java/top/suyiiyii/sims/common/JwtInterceptor.java index 63b0254..9c634f1 100644 --- a/src/main/java/top/suyiiyii/sims/common/JwtInterceptor.java +++ b/src/main/java/top/suyiiyii/sims/common/JwtInterceptor.java @@ -48,6 +48,8 @@ public class JwtInterceptor implements HandlerInterceptor { } } catch (TokenExpiredException e) { throw new ServiceException("401", "登录已过期,请重新登录"); + } catch (Exception e) { + throw new ServiceException("401", "token验证失败,请重新登录"); } // 获取 token 中的 user id Integer userId = Integer.parseInt(Objects.requireNonNull(JwtUtils.extractUserId(token))); @@ -55,4 +57,7 @@ public class JwtInterceptor implements HandlerInterceptor { request.setAttribute("userId", userId); return true; } + public static int getUserIdFromReq(HttpServletRequest request){ + return (int) request.getAttribute("userId"); + } } diff --git a/src/main/java/top/suyiiyii/sims/common/S3Config.java b/src/main/java/top/suyiiyii/sims/common/S3Config.java new file mode 100644 index 0000000..900618a --- /dev/null +++ b/src/main/java/top/suyiiyii/sims/common/S3Config.java @@ -0,0 +1,17 @@ +package top.suyiiyii.sims.common; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import top.suyiiyii.sims.utils.S3Client; + +@Configuration +public class S3Config { + @Bean + public static S3Client Config(@Value("${S3.ENDPOINT}") String endpoint, + @Value("${S3.ACCESS_KEY}") String accessKey, + @Value("${S3.SECRET_KEY}") String secretKey, + @Value("${S3.BUCKET}") String bucket) { + return new S3Client(endpoint, accessKey, secretKey, bucket); + } +} diff --git a/src/main/java/top/suyiiyii/sims/controller/FileController.java b/src/main/java/top/suyiiyii/sims/controller/FileController.java new file mode 100644 index 0000000..d66f458 --- /dev/null +++ b/src/main/java/top/suyiiyii/sims/controller/FileController.java @@ -0,0 +1,40 @@ +package top.suyiiyii.sims.controller; + +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.parameters.RequestBody; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.multipart.MultipartFile; +import top.suyiiyii.sims.common.AuthAccess; +import top.suyiiyii.sims.service.FileService; + +import java.io.IOException; +import java.io.InputStream; + +@Slf4j +@RestController +public class FileController { + @Autowired + FileService fileService; + + @AuthAccess(allowRoles = {"user"}) + @Operation(summary = "上传文件", description = "使用form-data格式\nfile:文件\nfilename: 文件名\n返回可访问的路径") + @PostMapping("/upload") + public String uploadFile( + @Parameter String filename, + @RequestBody(content = @Content(mediaType = "multipart/form-data", + schema = @Schema(type = "string", format = "binary"))) MultipartFile file) { + try (InputStream in = file.getInputStream()) { + log.info("文件上传,文件名:{},描述:{}", file.getOriginalFilename(), filename); + return fileService.uploadFile(in, filename); + } catch (IOException e) { + log.warn("文件上传失败", e); + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/top/suyiiyii/sims/controller/UserController.java b/src/main/java/top/suyiiyii/sims/controller/UserController.java index ed3b62a..d52b86f 100644 --- a/src/main/java/top/suyiiyii/sims/controller/UserController.java +++ b/src/main/java/top/suyiiyii/sims/controller/UserController.java @@ -1,19 +1,25 @@ package top.suyiiyii.sims.controller; import cn.hutool.core.util.StrUtil; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import io.swagger.v3.oas.annotations.Operation; import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpSession; -import jakarta.validation.constraints.Max; +import jakarta.validation.Valid; +import jakarta.validation.constraints.Email; import lombok.Data; import lombok.extern.slf4j.Slf4j; +import org.hibernate.validator.constraints.Length; +import org.hibernate.validator.constraints.Range; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; import top.suyiiyii.sims.common.AuthAccess; +import top.suyiiyii.sims.common.JwtInterceptor; import top.suyiiyii.sims.common.Result; import top.suyiiyii.sims.dto.CommonResponse; import top.suyiiyii.sims.dto.UserDto; +import top.suyiiyii.sims.entity.User; import top.suyiiyii.sims.exception.ServiceException; +import top.suyiiyii.sims.mapper.MpUserMapper; import top.suyiiyii.sims.service.RoleService; import top.suyiiyii.sims.service.UserService; @@ -36,12 +42,14 @@ public class UserController { @Autowired UserService userService; @Autowired + MpUserMapper mpUserMapper; + @Autowired RoleService roleService; @AuthAccess(allowRoles = {"guest"}) @PostMapping("/user/login") - public Result login(@RequestBody LoginRequest request,HttpServletRequest httpServletRequest) { + public Result login(@RequestBody LoginRequest request) { log.info("login request:{}", request); if (StrUtil.isBlank(request.getUsername()) || StrUtil.isBlank(request.getPassword())) { @@ -54,25 +62,28 @@ public class UserController { } LoginResponse response = new LoginResponse(); response.setToken(token); - HttpSession session = httpServletRequest.getSession(); - session.setAttribute("token",token); return Result.success(response); } + @AuthAccess(allowRoles = {"guest"}) @PostMapping("/user/register") - public Result register(@RequestBody RegisterRequest request) { + public Result register(@RequestBody @Valid + RegisterRequest request) { log.info("register request:{}", request); - if (StrUtil.isBlank(request.getUsername()) || StrUtil.isBlank(request.getPassword())) { - - return Result.error("用户名或密码不能为空"); + // 检查 username 是否已存在 + if (mpUserMapper.selectOne(new LambdaQueryWrapper(User.class).eq(User::getUsername, request.getUsername())) != null) { + throw new ServiceException("用户名已存在"); } - if (request.getPassword() == null || request.getPassword().length() < 3) { - throw new ServiceException("密码长度不能小于3位"); + // 检查 studentId 是否已存在 + if (mpUserMapper.selectOne(new LambdaQueryWrapper(User.class).eq(User::getStudentId, request.getStudentId())) != null) { + throw new ServiceException("学号已存在"); + } + // 检查 email 是否已存在 + if (mpUserMapper.selectOne(new LambdaQueryWrapper(User.class).eq(User::getEmail, request.getEmail())) != null) { + throw new ServiceException("邮箱已存在"); } - userService.register(request); - return Result.success(CommonResponse.factory("注册成功")); } @@ -105,19 +116,26 @@ public class UserController { @Operation(description = "获取当前用户信息") @AuthAccess(allowRoles = {"user"}) @GetMapping("/user/me") - public Result getSelf() { - UserDto user = userService.findUser(0); + public Result getSelf(HttpServletRequest request) { + int userId = JwtInterceptor.getUserIdFromReq(request); + UserDto user = userService.findUser(userId); return Result.success(user); } @Data public static class RegisterRequest { + @Length(min = 3, max = 20) private String username; - private Integer studentId; + @Length(min = 6, max = 20) private String password; + @Range(min = 1, max = 1000000000) + private Integer studentId; + @Email private String email; + @Length(min = 1, max = 20) private String grade; + @Length(min = 1, max = 20) private String userGroup; } diff --git a/src/main/java/top/suyiiyii/sims/entity/User.java b/src/main/java/top/suyiiyii/sims/entity/User.java index 3fd02d9..27d2025 100644 --- a/src/main/java/top/suyiiyii/sims/entity/User.java +++ b/src/main/java/top/suyiiyii/sims/entity/User.java @@ -39,10 +39,8 @@ public class User { @UniqueIndex @Column(comment = "邮箱", notNull = true) private String email; - @UniqueIndex @Column(comment = "年级", notNull = true) private String grade; - @UniqueIndex @Column(comment = "用户所属团队", notNull = true) private String userGroup; } diff --git a/src/main/java/top/suyiiyii/sims/service/FileService.java b/src/main/java/top/suyiiyii/sims/service/FileService.java new file mode 100644 index 0000000..f632a56 --- /dev/null +++ b/src/main/java/top/suyiiyii/sims/service/FileService.java @@ -0,0 +1,18 @@ +package top.suyiiyii.sims.service; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import top.suyiiyii.sims.utils.S3Client; + +import java.io.InputStream; + +@Service +public class FileService { + @Autowired + private S3Client s3Client; + + public String uploadFile(InputStream input,String fileName) { + String extension = fileName.substring(fileName.lastIndexOf(".")); + return s3Client.uploadFile(input, extension); + } +} diff --git a/src/main/java/top/suyiiyii/sims/service/UserService.java b/src/main/java/top/suyiiyii/sims/service/UserService.java index 5d2aafa..c47172b 100644 --- a/src/main/java/top/suyiiyii/sims/service/UserService.java +++ b/src/main/java/top/suyiiyii/sims/service/UserService.java @@ -70,27 +70,9 @@ public class UserService { public void register(UserController.RegisterRequest req) { User dbUser = userMapper.selectByUserId(req.getStudentId()); - - if (req.getUsername() == null || req.getUsername().equals("")) { - throw new ServiceException("用户名不能为空"); - } - if (dbUser != null) { throw new ServiceException("账号已经存在"); } - if (req.getStudentId() == null || req.getStudentId().equals("")) { - throw new ServiceException("学号不能为空"); - } - if (req.getPassword() == null || req.getPassword().equals("")) { - - throw new ServiceException("密码不能为空"); - } - if (req.getEmail() == null || req.getEmail().equals("")) { - throw new ServiceException("邮箱不能为空"); - } - if (req.getUserGroup() == null || req.getUserGroup().equals("")) { - throw new ServiceException("组别不能为空"); - } User user = modelMapper.map(req, User.class); mpUserMapper.insert(user); user = mpUserMapper.selectOne(new LambdaQueryWrapper().eq(User::getUsername, req.getUsername())); @@ -110,14 +92,6 @@ public class UserService { UserDto.setUserGroup(user.getUserGroup()); UserDto.setRoles(new ArrayList<>()); Integer id = user.getId(); - List roles = roleMapper.selectRolesById(id); - for (Role role : roles) { - Integer roleId = role.getId(); - // 获取一个角色的名称列表 - List roleNameList = roleMapper.selectRoleNamesByRoleId(roleId); - // 累加角色名称到用户的角色列表中 - UserDto.getRoles().addAll(roleNameList); - } UserDtos.add(UserDto); } return UserDtos; @@ -127,31 +101,23 @@ public class UserService { UserDto UserDto = new UserDto(); User user = userMapper.selectById(id); + if (user == null) { + throw new ServiceException("用户不存在"); + } UserDto.setUserId(user.getId()); UserDto.setUsername(user.getUsername()); UserDto.setGrade(user.getGrade()); UserDto.setUserGroup(user.getUserGroup()); UserDto.setRoles(new ArrayList<>()); - List roles = roleMapper.selectRolesById(id); - for (Role role : roles) { - Integer roleId = role.getId(); - // 获取一个角色的名称列表 - List roleNameList = roleMapper.selectRoleNamesByRoleId(roleId); - // 累加角色名称到用户的角色列表中 - UserDto.getRoles().addAll(roleNameList); - } - - + //TODO: 获取用户角色 return UserDto; } -/* + public User selectByUserId(Integer studentId) { + return userMapper.selectByUserId(studentId); + } + public List selectRolesById(Integer studentId) { return roleMapper.selectRolesById(studentId); } -*/ - - public Integer getStudentIdByUserId(Integer userId) { - return userMapper.getStudentIdByUserId(userId); - } } diff --git a/src/main/java/top/suyiiyii/sims/utils/S3Client.java b/src/main/java/top/suyiiyii/sims/utils/S3Client.java new file mode 100644 index 0000000..837f121 --- /dev/null +++ b/src/main/java/top/suyiiyii/sims/utils/S3Client.java @@ -0,0 +1,123 @@ +package top.suyiiyii.sims.utils; + + +import com.amazonaws.ClientConfiguration; +import com.amazonaws.Protocol; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.services.s3.AmazonS3; +import com.amazonaws.services.s3.AmazonS3Client; +import com.amazonaws.services.s3.S3ClientOptions; +import com.amazonaws.services.s3.model.ObjectMetadata; +import com.amazonaws.services.s3.model.S3Object; +import org.apache.commons.io.IOUtils; + +import java.io.InputStream; +import java.io.OutputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.util.UUID; + + +public class S3Client { + private final String endpoint; + private final String bucket; + private final AmazonS3 s3client; + + public S3Client(String endpoint, String accessKey, String secretKey, String bucket) { + this.endpoint = endpoint; + this.bucket = bucket; + URL endpointUrl; + try { + endpointUrl = new URL(endpoint); + } catch (MalformedURLException e) { + throw new RuntimeException(e); + } + String protocol = endpointUrl.getProtocol(); + int port = endpointUrl.getPort() == -1 ? endpointUrl.getDefaultPort() : endpointUrl.getPort(); + ClientConfiguration clientConfig = new ClientConfiguration(); + clientConfig.setSignerOverride("S3SignerType"); + clientConfig.setProtocol(Protocol.valueOf(protocol.toUpperCase())); + // 禁用证书检查,避免https自签证书校验失败 + System.setProperty("com.amazonaws.sdk.disableCertChecking", "true"); + // 屏蔽 AWS 的 MD5 校验,避免校验导致的下载抛出异常问题 + System.setProperty("com.amazonaws.services.s3.disableGetObjectMD5Validation", "true"); + AWSCredentials awsCredentials = new BasicAWSCredentials(accessKey, secretKey); + // 创建 S3Client 实例 + AmazonS3 s3client = new AmazonS3Client(awsCredentials, clientConfig); + s3client.setEndpoint(endpointUrl.getHost() + ":" + port); + s3client.setS3ClientOptions(S3ClientOptions.builder().setPathStyleAccess(true).build()); + this.s3client = s3client; + } + + + public boolean bucketExists(String bucket) { + try { + return s3client.doesBucketExist(bucket); + } catch (Exception e) { + e.printStackTrace(); + } + return false; + } + + public boolean existObject(String bucket, String objectId) { + try { + return s3client.doesObjectExist(bucket, objectId); + } catch (Exception e) { + e.printStackTrace(); + return false; + } + } + + public InputStream download(String bucket, String objectId) { + try { + S3Object o = s3client.getObject(bucket, objectId); + return o.getObjectContent(); + } catch (Exception e) { + e.printStackTrace(); + } + return null; + } + + public void download(String bucket, String objectId, OutputStream out) { + S3Object o = s3client.getObject(bucket, objectId); + try (InputStream in = o.getObjectContent()) { + IOUtils.copyLarge(in, out); + } catch (Exception e) { + e.printStackTrace(); + } + } + + public void upload(String bucket, String objectId, InputStream input) { + try { + // 创建文件上传的元数据 + ObjectMetadata meta = new ObjectMetadata(); + // 设置文件上传长度 + meta.setContentLength(input.available()); + // 上传 + s3client.putObject(bucket, objectId, input, meta); + } catch (Exception e) { + e.printStackTrace(); + } + } + + public String uploadFile(InputStream input) { + String objectID = UUID.randomUUID().toString(); + upload(bucket, objectID, input); + return endpoint + "/" + bucket + "/" + objectID; + } + + /** + * 接收文件流,自动使用随机uuid命名并保留扩展名 + * 返回公网上可以直接访问的URL + * + * @param input 文件流 + * @param extensionName 扩展名 + * @return 文件的URL + */ + public String uploadFile(InputStream input, String extensionName) { + String objectID = UUID.randomUUID() + extensionName; + upload(bucket, objectID, input); + return endpoint + "/" + bucket + "/" + objectID; + } +} diff --git a/src/main/resources/application.yaml b/src/main/resources/application.yaml index 4b7599a..d54a8ca 100644 --- a/src/main/resources/application.yaml +++ b/src/main/resources/application.yaml @@ -2,6 +2,10 @@ spring: profiles: active: prod + servlet: + multipart: + max-file-size: 100MB + max-request-size: 100MB datasource: url: ${DATASOURCE_URL} username: ${DATASOURCE_USERNAME} @@ -14,4 +18,10 @@ auto-table: model-package: top.suyiiyii.sims.entity jwt: - secret: ${JWT_SECRET} \ No newline at end of file + secret: ${JWT_SECRET} + +S3: + ENDPOINT: ${S3_ENDPOINT} + BUCKET: ${S3_BUCKET} + ACCESS_KEY: ${S3_ACCESS_KEY} + SECRET_KEY: ${S3_SECRET_KEY} diff --git a/src/test/java/top/suyiiyii/sims/service/RbacServiceTest.java b/src/test/java/top/suyiiyii/sims/service/RbacServiceTest.java index b06cec6..f46b2ae 100644 --- a/src/test/java/top/suyiiyii/sims/service/RbacServiceTest.java +++ b/src/test/java/top/suyiiyii/sims/service/RbacServiceTest.java @@ -1,10 +1,14 @@ package top.suyiiyii.sims.service; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.mock.mockito.MockBean; import org.springframework.test.context.ActiveProfiles; import top.suyiiyii.sims.entity.Role; +import top.suyiiyii.sims.utils.S3Client; import java.util.List; @@ -12,11 +16,15 @@ import static org.junit.jupiter.api.Assertions.*; @SpringBootTest @ActiveProfiles("test") +@ExtendWith(MockitoExtension.class) class RbacServiceTest { @Autowired private RbacService rbacService; + @MockBean + private S3Client s3Client; + @Test void addRoleWithUserId() { int userId = 1; // mock userId