Merge remote-tracking branch 'origin/main' into wr

# Conflicts:
#	src/main/java/top/suyiiyii/sims/controller/UserController.java
#	src/main/java/top/suyiiyii/sims/service/UserService.java
This commit is contained in:
suyiiyii 2024-09-07 17:34:02 +08:00
commit 394156a639
11 changed files with 287 additions and 62 deletions

24
pom.xml
View File

@ -95,16 +95,38 @@
<artifactId>spring-restdocs-mockmvc</artifactId>
<scope>test</scope>
</dependency>
<!-- https://mvnrepository.com/artifact/org.springdoc/springdoc-openapi-starter-webmvc-ui -->
<dependency>
<groupId>org.springdoc</groupId>
<artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
<version>2.3.0</version>
<version>2.6.0</version>
</dependency>
<dependency>
<groupId>org.xerial</groupId>
<artifactId>sqlite-jdbc</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-s3</artifactId>
<version>1.12.706</version>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.11.0</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>4.0.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-validation</artifactId>
</dependency>
</dependencies>
<build>

View File

@ -48,6 +48,8 @@ public class JwtInterceptor implements HandlerInterceptor {
}
} catch (TokenExpiredException e) {
throw new ServiceException("401", "登录已过期,请重新登录");
} catch (Exception e) {
throw new ServiceException("401", "token验证失败请重新登录");
}
// 获取 token 中的 user id
Integer userId = Integer.parseInt(Objects.requireNonNull(JwtUtils.extractUserId(token)));
@ -55,4 +57,7 @@ public class JwtInterceptor implements HandlerInterceptor {
request.setAttribute("userId", userId);
return true;
}
public static int getUserIdFromReq(HttpServletRequest request){
return (int) request.getAttribute("userId");
}
}

View File

@ -0,0 +1,17 @@
package top.suyiiyii.sims.common;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import top.suyiiyii.sims.utils.S3Client;
@Configuration
public class S3Config {
@Bean
public static S3Client Config(@Value("${S3.ENDPOINT}") String endpoint,
@Value("${S3.ACCESS_KEY}") String accessKey,
@Value("${S3.SECRET_KEY}") String secretKey,
@Value("${S3.BUCKET}") String bucket) {
return new S3Client(endpoint, accessKey, secretKey, bucket);
}
}

View File

@ -0,0 +1,40 @@
package top.suyiiyii.sims.controller;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
import top.suyiiyii.sims.common.AuthAccess;
import top.suyiiyii.sims.service.FileService;
import java.io.IOException;
import java.io.InputStream;
@Slf4j
@RestController
public class FileController {
@Autowired
FileService fileService;
@AuthAccess(allowRoles = {"user"})
@Operation(summary = "上传文件", description = "使用form-data格式\nfile文件\nfilename 文件名\n返回可访问的路径")
@PostMapping("/upload")
public String uploadFile(
@Parameter String filename,
@RequestBody(content = @Content(mediaType = "multipart/form-data",
schema = @Schema(type = "string", format = "binary"))) MultipartFile file) {
try (InputStream in = file.getInputStream()) {
log.info("文件上传,文件名:{},描述:{}", file.getOriginalFilename(), filename);
return fileService.uploadFile(in, filename);
} catch (IOException e) {
log.warn("文件上传失败", e);
throw new RuntimeException(e);
}
}
}

View File

@ -1,19 +1,25 @@
package top.suyiiyii.sims.controller;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import io.swagger.v3.oas.annotations.Operation;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpSession;
import jakarta.validation.constraints.Max;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Email;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.hibernate.validator.constraints.Length;
import org.hibernate.validator.constraints.Range;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import top.suyiiyii.sims.common.AuthAccess;
import top.suyiiyii.sims.common.JwtInterceptor;
import top.suyiiyii.sims.common.Result;
import top.suyiiyii.sims.dto.CommonResponse;
import top.suyiiyii.sims.dto.UserDto;
import top.suyiiyii.sims.entity.User;
import top.suyiiyii.sims.exception.ServiceException;
import top.suyiiyii.sims.mapper.MpUserMapper;
import top.suyiiyii.sims.service.RoleService;
import top.suyiiyii.sims.service.UserService;
@ -36,12 +42,14 @@ public class UserController {
@Autowired
UserService userService;
@Autowired
MpUserMapper mpUserMapper;
@Autowired
RoleService roleService;
@AuthAccess(allowRoles = {"guest"})
@PostMapping("/user/login")
public Result<LoginResponse> login(@RequestBody LoginRequest request,HttpServletRequest httpServletRequest) {
public Result<LoginResponse> login(@RequestBody LoginRequest request) {
log.info("login request:{}", request);
if (StrUtil.isBlank(request.getUsername()) || StrUtil.isBlank(request.getPassword())) {
@ -54,25 +62,28 @@ public class UserController {
}
LoginResponse response = new LoginResponse();
response.setToken(token);
HttpSession session = httpServletRequest.getSession();
session.setAttribute("token",token);
return Result.success(response);
}
@AuthAccess(allowRoles = {"guest"})
@PostMapping("/user/register")
public Result<CommonResponse> register(@RequestBody RegisterRequest request) {
public Result<CommonResponse> register(@RequestBody @Valid
RegisterRequest request) {
log.info("register request:{}", request);
if (StrUtil.isBlank(request.getUsername()) || StrUtil.isBlank(request.getPassword())) {
return Result.error("用户名或密码不能为空");
// 检查 username 是否已存在
if (mpUserMapper.selectOne(new LambdaQueryWrapper<User>(User.class).eq(User::getUsername, request.getUsername())) != null) {
throw new ServiceException("用户名已存在");
}
if (request.getPassword() == null || request.getPassword().length() < 3) {
throw new ServiceException("密码长度不能小于3位");
// 检查 studentId 是否已存在
if (mpUserMapper.selectOne(new LambdaQueryWrapper<User>(User.class).eq(User::getStudentId, request.getStudentId())) != null) {
throw new ServiceException("学号已存在");
}
// 检查 email 是否已存在
if (mpUserMapper.selectOne(new LambdaQueryWrapper<User>(User.class).eq(User::getEmail, request.getEmail())) != null) {
throw new ServiceException("邮箱已存在");
}
userService.register(request);
return Result.success(CommonResponse.factory("注册成功"));
}
@ -105,19 +116,26 @@ public class UserController {
@Operation(description = "获取当前用户信息")
@AuthAccess(allowRoles = {"user"})
@GetMapping("/user/me")
public Result<UserDto> getSelf() {
UserDto user = userService.findUser(0);
public Result<UserDto> getSelf(HttpServletRequest request) {
int userId = JwtInterceptor.getUserIdFromReq(request);
UserDto user = userService.findUser(userId);
return Result.success(user);
}
@Data
public static class RegisterRequest {
@Length(min = 3, max = 20)
private String username;
private Integer studentId;
@Length(min = 6, max = 20)
private String password;
@Range(min = 1, max = 1000000000)
private Integer studentId;
@Email
private String email;
@Length(min = 1, max = 20)
private String grade;
@Length(min = 1, max = 20)
private String userGroup;
}

View File

@ -39,10 +39,8 @@ public class User {
@UniqueIndex
@Column(comment = "邮箱", notNull = true)
private String email;
@UniqueIndex
@Column(comment = "年级", notNull = true)
private String grade;
@UniqueIndex
@Column(comment = "用户所属团队", notNull = true)
private String userGroup;
}

View File

@ -0,0 +1,18 @@
package top.suyiiyii.sims.service;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import top.suyiiyii.sims.utils.S3Client;
import java.io.InputStream;
@Service
public class FileService {
@Autowired
private S3Client s3Client;
public String uploadFile(InputStream input,String fileName) {
String extension = fileName.substring(fileName.lastIndexOf("."));
return s3Client.uploadFile(input, extension);
}
}

View File

@ -70,27 +70,9 @@ public class UserService {
public void register(UserController.RegisterRequest req) {
User dbUser = userMapper.selectByUserId(req.getStudentId());
if (req.getUsername() == null || req.getUsername().equals("")) {
throw new ServiceException("用户名不能为空");
}
if (dbUser != null) {
throw new ServiceException("账号已经存在");
}
if (req.getStudentId() == null || req.getStudentId().equals("")) {
throw new ServiceException("学号不能为空");
}
if (req.getPassword() == null || req.getPassword().equals("")) {
throw new ServiceException("密码不能为空");
}
if (req.getEmail() == null || req.getEmail().equals("")) {
throw new ServiceException("邮箱不能为空");
}
if (req.getUserGroup() == null || req.getUserGroup().equals("")) {
throw new ServiceException("组别不能为空");
}
User user = modelMapper.map(req, User.class);
mpUserMapper.insert(user);
user = mpUserMapper.selectOne(new LambdaQueryWrapper<User>().eq(User::getUsername, req.getUsername()));
@ -110,14 +92,6 @@ public class UserService {
UserDto.setUserGroup(user.getUserGroup());
UserDto.setRoles(new ArrayList<>());
Integer id = user.getId();
List<Role> roles = roleMapper.selectRolesById(id);
for (Role role : roles) {
Integer roleId = role.getId();
// 获取一个角色的名称列表
List<String> roleNameList = roleMapper.selectRoleNamesByRoleId(roleId);
// 累加角色名称到用户的角色列表中
UserDto.getRoles().addAll(roleNameList);
}
UserDtos.add(UserDto);
}
return UserDtos;
@ -127,31 +101,23 @@ public class UserService {
UserDto UserDto = new UserDto();
User user = userMapper.selectById(id);
if (user == null) {
throw new ServiceException("用户不存在");
}
UserDto.setUserId(user.getId());
UserDto.setUsername(user.getUsername());
UserDto.setGrade(user.getGrade());
UserDto.setUserGroup(user.getUserGroup());
UserDto.setRoles(new ArrayList<>());
List<Role> roles = roleMapper.selectRolesById(id);
for (Role role : roles) {
Integer roleId = role.getId();
// 获取一个角色的名称列表
List<String> roleNameList = roleMapper.selectRoleNamesByRoleId(roleId);
// 累加角色名称到用户的角色列表中
UserDto.getRoles().addAll(roleNameList);
}
//TODO: 获取用户角色
return UserDto;
}
/*
public User selectByUserId(Integer studentId) {
return userMapper.selectByUserId(studentId);
}
public List<Role> selectRolesById(Integer studentId) {
return roleMapper.selectRolesById(studentId);
}
*/
public Integer getStudentIdByUserId(Integer userId) {
return userMapper.getStudentIdByUserId(userId);
}
}

View File

@ -0,0 +1,123 @@
package top.suyiiyii.sims.utils;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.Protocol;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3Client;
import com.amazonaws.services.s3.S3ClientOptions;
import com.amazonaws.services.s3.model.ObjectMetadata;
import com.amazonaws.services.s3.model.S3Object;
import org.apache.commons.io.IOUtils;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.UUID;
public class S3Client {
private final String endpoint;
private final String bucket;
private final AmazonS3 s3client;
public S3Client(String endpoint, String accessKey, String secretKey, String bucket) {
this.endpoint = endpoint;
this.bucket = bucket;
URL endpointUrl;
try {
endpointUrl = new URL(endpoint);
} catch (MalformedURLException e) {
throw new RuntimeException(e);
}
String protocol = endpointUrl.getProtocol();
int port = endpointUrl.getPort() == -1 ? endpointUrl.getDefaultPort() : endpointUrl.getPort();
ClientConfiguration clientConfig = new ClientConfiguration();
clientConfig.setSignerOverride("S3SignerType");
clientConfig.setProtocol(Protocol.valueOf(protocol.toUpperCase()));
// 禁用证书检查避免https自签证书校验失败
System.setProperty("com.amazonaws.sdk.disableCertChecking", "true");
// 屏蔽 AWS MD5 校验避免校验导致的下载抛出异常问题
System.setProperty("com.amazonaws.services.s3.disableGetObjectMD5Validation", "true");
AWSCredentials awsCredentials = new BasicAWSCredentials(accessKey, secretKey);
// 创建 S3Client 实例
AmazonS3 s3client = new AmazonS3Client(awsCredentials, clientConfig);
s3client.setEndpoint(endpointUrl.getHost() + ":" + port);
s3client.setS3ClientOptions(S3ClientOptions.builder().setPathStyleAccess(true).build());
this.s3client = s3client;
}
public boolean bucketExists(String bucket) {
try {
return s3client.doesBucketExist(bucket);
} catch (Exception e) {
e.printStackTrace();
}
return false;
}
public boolean existObject(String bucket, String objectId) {
try {
return s3client.doesObjectExist(bucket, objectId);
} catch (Exception e) {
e.printStackTrace();
return false;
}
}
public InputStream download(String bucket, String objectId) {
try {
S3Object o = s3client.getObject(bucket, objectId);
return o.getObjectContent();
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
public void download(String bucket, String objectId, OutputStream out) {
S3Object o = s3client.getObject(bucket, objectId);
try (InputStream in = o.getObjectContent()) {
IOUtils.copyLarge(in, out);
} catch (Exception e) {
e.printStackTrace();
}
}
public void upload(String bucket, String objectId, InputStream input) {
try {
// 创建文件上传的元数据
ObjectMetadata meta = new ObjectMetadata();
// 设置文件上传长度
meta.setContentLength(input.available());
// 上传
s3client.putObject(bucket, objectId, input, meta);
} catch (Exception e) {
e.printStackTrace();
}
}
public String uploadFile(InputStream input) {
String objectID = UUID.randomUUID().toString();
upload(bucket, objectID, input);
return endpoint + "/" + bucket + "/" + objectID;
}
/**
* 接收文件流自动使用随机uuid命名并保留扩展名
* 返回公网上可以直接访问的URL
*
* @param input 文件流
* @param extensionName 扩展名
* @return 文件的URL
*/
public String uploadFile(InputStream input, String extensionName) {
String objectID = UUID.randomUUID() + extensionName;
upload(bucket, objectID, input);
return endpoint + "/" + bucket + "/" + objectID;
}
}

View File

@ -2,6 +2,10 @@
spring:
profiles:
active: prod
servlet:
multipart:
max-file-size: 100MB
max-request-size: 100MB
datasource:
url: ${DATASOURCE_URL}
username: ${DATASOURCE_USERNAME}
@ -14,4 +18,10 @@ auto-table:
model-package: top.suyiiyii.sims.entity
jwt:
secret: ${JWT_SECRET}
secret: ${JWT_SECRET}
S3:
ENDPOINT: ${S3_ENDPOINT}
BUCKET: ${S3_BUCKET}
ACCESS_KEY: ${S3_ACCESS_KEY}
SECRET_KEY: ${S3_SECRET_KEY}

View File

@ -1,10 +1,14 @@
package top.suyiiyii.sims.service;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.test.context.ActiveProfiles;
import top.suyiiyii.sims.entity.Role;
import top.suyiiyii.sims.utils.S3Client;
import java.util.List;
@ -12,11 +16,15 @@ import static org.junit.jupiter.api.Assertions.*;
@SpringBootTest
@ActiveProfiles("test")
@ExtendWith(MockitoExtension.class)
class RbacServiceTest {
@Autowired
private RbacService rbacService;
@MockBean
private S3Client s3Client;
@Test
void addRoleWithUserId() {
int userId = 1; // mock userId