This commit is contained in:
2024-11-27 23:22:08 +08:00
commit 28c518b355
108 changed files with 30312 additions and 0 deletions
+206
View File
@@ -0,0 +1,206 @@
.DS_Store
node_modules
*.log
explorations
TODOs.md
RELEASE_NOTE*.md
packages/server-renderer/basic.js
packages/server-renderer/build.dev.js
packages/server-renderer/build.prod.js
packages/server-renderer/server-plugin.js
packages/server-renderer/client-plugin.js
packages/template-compiler/build.js
packages/template-compiler/browser.js
.vscode
dist
temp
types/v3-generated.d.ts
# Compiled class file
*.class
# Log file
*.log
# BlueJ files
*.ctxt
# Mobile Tools for Java (J2ME)
.mtj.tmp/
# Package Files #
*.jar
*.war
*.nar
*.ear
*.zip
*.tar.gz
*.rar
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
hs_err_pid*
replay_pid*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
+33
View File
@@ -0,0 +1,33 @@
HELP.md
target/
!.mvn/wrapper/maven-wrapper.jar
!**/src/main/**/target/
!**/src/test/**/target/
### STS ###
.apt_generated
.classpath
.factorypath
.project
.settings
.springBeans
.sts4-cache
### IntelliJ IDEA ###
.idea
*.iws
*.iml
*.ipr
### NetBeans ###
/nbproject/private/
/nbbuild/
/dist/
/nbdist/
/.nb-gradle/
build/
!**/src/main/**/build/
!**/src/test/**/build/
### VS Code ###
.vscode/
+86
View File
@@ -0,0 +1,86 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.6.11</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>com.common</groupId>
<artifactId>backend2</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>backend2</name>
<description>backend2</description>
<properties>
<java.version>1.8</java.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-jdbc</artifactId>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-boot-starter</artifactId>
<version>2.2.0</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-lang3 -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.9</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
@@ -0,0 +1,13 @@
package com.common.backend2;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
/**
 * Spring Boot entry point of the backend2 service.
 */
@SpringBootApplication
public class Backend2Application {

    /**
     * Starts the application, using this class as the primary configuration source.
     *
     * @param args command-line arguments handed to Spring Boot unchanged
     */
    public static void main(String[] args) {
        SpringApplication application = new SpringApplication(Backend2Application.class);
        application.run(args);
    }
}
@@ -0,0 +1,63 @@
package com.common.backend2.config;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.List;
public class CommonInterceptor implements HandlerInterceptor {
private List<String> excludedUrls;
public List<String> getExcludedUrls() {
return excludedUrls;
}
public void setExcludedUrls(List<String> excludedUrls) {
this.excludedUrls = excludedUrls;
}
/**
*
* 在业务处理器处理请求之前被调用 如果返回false
* 从当前的拦截器往回执行所有拦截器的afterCompletion(),
* 再退出拦截器链, 如果返回true 执行下一个拦截器,
* 直到所有的拦截器都执行完毕 再执行被拦截的Controller
* 然后进入拦截器链,
* 从最后一个拦截器往回执行所有的postHandle()
* 接着再从最后一个拦截器往回执行所有的afterCompletion()
* @param request
* @param response
*/
public boolean preHandle(HttpServletRequest request, HttpServletResponse response,
Object handler) throws Exception {
response.setHeader("Access-Control-Allow-Origin", "*");
response.setHeader("Access-Control-Allow-Methods", "*");
response.setHeader("Access-Control-Max-Age", "3600");
response.setHeader("Access-Control-Allow-Headers",
"Origin, X-Requested-With, Content-Type, Accept");
return true;
}
// 在业务处理器处理请求执行完成后,生成视图之前执行的动作
public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler,
ModelAndView modelAndView) throws Exception {
}
/**
* 在DispatcherServlet完全处理完请求后被调用
* 当有拦截器抛出异常时,
* 会从当前拦截器往回执行所有的拦截器的afterCompletion()
* @param request
* @param response
* @param handler
*
*/
public void afterCompletion(HttpServletRequest request, HttpServletResponse response,
Object handler, Exception ex) throws Exception {
}
}
@@ -0,0 +1,25 @@
package com.common.backend2.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.web.client.RestTemplate;
/**
 * Declares the {@link RestTemplate} bean and the request factory it is built on.
 */
@Configuration
public class ResTemplateConfig {

    /**
     * Request factory with explicit timeouts.
     *
     * @return factory using a 15000 ms connect timeout and a 10000 ms read timeout
     */
    @Bean
    public ClientHttpRequestFactory simpleClientHttpRequestFactory() {
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        // Timeouts are given in milliseconds.
        requestFactory.setReadTimeout(10000);
        requestFactory.setConnectTimeout(15000);
        return requestFactory;
    }

    /**
     * Builds the shared {@link RestTemplate}.
     *
     * @param factory request factory injected by the container
     * @return a RestTemplate backed by {@code factory}
     */
    @Bean
    public RestTemplate restTemplate(ClientHttpRequestFactory factory) {
        return new RestTemplate(factory);
    }
}
@@ -0,0 +1,55 @@
package com.common.backend2.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import java.io.File;
/**
 * Web MVC customisation: CORS mapping, the common interceptor and static resource locations.
 */
@Configuration
public class WebMvcConfig implements WebMvcConfigurer {

    // NOTE(review): both locations are absolute paths on a Windows drive, while
    // application.yml configures /home/... directories — confirm which one is intended
    // for deployment.
    private static final String RESULT_LOCATION = "file:///E:/Document/project/temp/save/";
    private static final String ASSETS_LOCATION = "file:///E:/Document/project/temp/assets/";

    /**
     * Exposes the interceptor as a bean so that it can be registered below.
     *
     * @return a new {@link CommonInterceptor}
     */
    @Bean
    public CommonInterceptor CommonInterceptor() {
        return new CommonInterceptor();
    }

    /**
     * Allows cross-origin calls to every path, from any origin pattern, with credentials.
     */
    @Override
    public void addCorsMappings(CorsRegistry registry) {
        registry.addMapping("/**")
                .allowedHeaders("*")
                .allowedOriginPatterns("*")
                .allowedMethods("POST", "GET", "PUT", "OPTIONS", "DELETE")
                .maxAge(3600)
                .allowCredentials(true);
    }

    /**
     * Registers {@link CommonInterceptor} for all request paths.
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        WebMvcConfigurer.super.addInterceptors(registry);
        registry.addInterceptor(CommonInterceptor()).addPathPatterns("/**");
    }

    /**
     * Maps URL prefixes to directories on disk:
     * {@code addResourceHandler} takes the request path pattern,
     * {@code addResourceLocations} the absolute resource location.
     */
    @Override
    public void addResourceHandlers(ResourceHandlerRegistry registry) {
        registry.addResourceHandler("/result/**").addResourceLocations(RESULT_LOCATION);
        registry.addResourceHandler("/assets/**").addResourceLocations(ASSETS_LOCATION);
    }
}
@@ -0,0 +1,109 @@
package com.common.backend2.controller;
import com.common.backend2.entity.File;
import com.common.backend2.response.Result;
import com.common.backend2.response.ResponseCode;
import com.common.backend2.service.FileService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.LinkedList;
import java.util.Queue;
/**
 * REST endpoints under {@code /api} for uploading, processing, downloading, listing and
 * deleting files. All work is delegated to {@link FileService}.
 */
@CrossOrigin
@RestController
@RequestMapping("/api")
public class FileController {

    @Autowired
    private FileService fileService;

    Queue<String> queue = new LinkedList<>(); //test_patient

    /**
     * Uploads several files in one request.
     *
     * <p>Files are stored one by one; the first file that the service rejects ends the
     * request and its result is returned. Files stored before that point are kept.
     *
     * @param files   the multipart files sent by the client
     * @param request current request, forwarded to the service
     * @return FILE_EMPTY when no file was sent, the first failing result, or SUCCESS
     */
    @RequestMapping(value = "/multi/uploadMultiImage", method = RequestMethod.POST)
    public Result uploadMultiImage(@RequestParam("files") MultipartFile[] files, HttpServletRequest request) {
        if (files == null || files.length == 0) {
            return new Result(ResponseCode.FILE_EMPTY.getCode(), ResponseCode.FILE_EMPTY.getMsg(), null);
        }
        for (MultipartFile file : files) {
            Result outcome = fileService.upLoadFiles(file, request);
            // Do not report success for the batch when one of the files was refused.
            if (outcome != null && outcome.getCode() != ResponseCode.SUCCESS.getCode()) {
                return outcome;
            }
        }
        return new Result(ResponseCode.SUCCESS.getCode(), ResponseCode.SUCCESS.getMsg(), "数据上传成功");
    }

    /**
     * Uploads a single file.
     *
     * @param multipartFile the uploaded file
     * @param request       current request, forwarded to the service
     * @return FILE_EMPTY for an empty upload, otherwise the service result
     */
    @RequestMapping(value = "/upload", method = RequestMethod.POST)
    public Result upLoadFile(@RequestParam("file") MultipartFile multipartFile, HttpServletRequest request) {
        if (multipartFile.isEmpty()) {
            return new Result(ResponseCode.FILE_EMPTY.getCode(), ResponseCode.FILE_EMPTY.getMsg(), null);
        }
        return fileService.upLoadFiles(multipartFile, request);
    }

    /**
     * Runs processing for a stored file and returns the service result.
     *
     * @param id id of the file record
     */
    @RequestMapping(value = "/predict/{fileId}", method = RequestMethod.GET)
    public Result predictAndreturnResult(@PathVariable("fileId") Integer id) {
        return fileService.processFile(id);
    }

    /**
     * Streams a stored file to the client as an attachment.
     *
     * <p>Responds with 404 when the record or its content cannot be found. I/O errors
     * while copying are logged and end the transfer.
     *
     * @param id       id of the file record
     * @param request  current request (unused)
     * @param response response that receives the headers and the file bytes
     */
    @RequestMapping(value = "/download/{id}", method = RequestMethod.GET)
    public void downloadFiles(@PathVariable("id") String id, HttpServletRequest request, HttpServletResponse response) {
        File file = fileService.getFileById(id);
        if (file == null || file.getFileName() == null) {
            response.setStatus(HttpServletResponse.SC_NOT_FOUND);
            return;
        }
        InputStream source = fileService.getFileInputStream(file);
        if (source == null) {
            response.setStatus(HttpServletResponse.SC_NOT_FOUND);
            return;
        }

        String fileName = file.getFileName();
        // Header values are sent as ISO-8859-1; re-encode the UTF-8 bytes so that
        // non-ASCII names survive the trip.
        response.setHeader("Content-Disposition", "attachment;filename="
                + new String(fileName.getBytes(StandardCharsets.UTF_8), StandardCharsets.ISO_8859_1));
        response.setContentType("application/force-download");

        byte[] buffer = new byte[1024];
        // try-with-resources closes both streams on every path, including failures.
        try (InputStream in = new BufferedInputStream(source);
             OutputStream out = response.getOutputStream()) {
            int read;
            while ((read = in.read(buffer)) != -1) {
                out.write(buffer, 0, read);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns the result URLs that the service resolves for the calling client.
     */
    @RequestMapping(value = "/getResult", method = RequestMethod.GET)
    public Result getUrlsByIp(HttpServletRequest request, HttpServletResponse response) {
        return fileService.returnUrls(request, response);
    }

    /**
     * Deletes a file record through the service.
     *
     * @param id id of the file record
     */
    @RequestMapping(value = "/delete/{id}", method = RequestMethod.DELETE)
    public Result deleteOriginAndProcess(@PathVariable("id") Integer id) {
        return fileService.deleteOne(id);
    }
}
@@ -0,0 +1,56 @@
package com.common.backend2.entity;
import com.baomidou.mybatisplus.annotations.TableId;
import com.baomidou.mybatisplus.annotations.TableName;
import com.baomidou.mybatisplus.enums.IdType;
import lombok.*;
import java.io.Serializable;
/**
 * Entity mapped to the "file" table; one instance describes one stored file.
 * Accessors, equals/hashCode and the no-args / all-args constructors are generated by Lombok.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
@Getter
@Setter
@EqualsAndHashCode
@TableName(value = "file")
public class File implements Serializable {
private static final long serialVersionUID=1L;
/**
* File id (column "id", IdType.AUTO strategy).
*/
@TableId(value="id",type = IdType.AUTO)
private Integer fileId;
/**
* Storage path of the file (filled from saveChildPath by the constructor below).
*/
private String filePath;
/**
* Storage path of the result (filled from resultChildPath by the constructor below).
*/
private String savePath;
/**
* File name.
*/
private String fileName;
/**
*
* Whether the file has been processed.
*
*/
private Integer processed;
// Result text stored for the file; left unset by the four-argument constructor.
private String result;
/**
* Creates a record without an id. The first argument is ignored; processed and
* result are not assigned here.
*
* @param o ignored
* @param saveChildPath value stored in filePath
* @param resultChildPath value stored in savePath
* @param fileName value stored in fileName
*/
public File(Object o, String saveChildPath, String resultChildPath, String fileName) {
this.fileId=null;
this.filePath= saveChildPath;
this.savePath = resultChildPath;
this.fileName = fileName;
}
}
@@ -0,0 +1,25 @@
package com.common.backend2.mapper;
import com.baomidou.mybatisplus.mapper.BaseMapper;
import com.common.backend2.entity.File;
import org.apache.ibatis.annotations.Insert;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Select;
import org.springframework.stereotype.Repository;
import java.util.List;
/**
 * MyBatis mapper for {@link File}; inherits the generic CRUD operations of BaseMapper
 * and adds two hand-written queries.
 */
@Mapper
@Repository
public interface FileMapper extends BaseMapper<File> {
/** Selects every row of the file table. */
@Select("select * from file")
List<File>findAll();
/**
 * Runs {@code SELECT LAST_INSERT_ID()} and returns its value.
 * NOTE(review): presumably meant to be called right after an insert on the same
 * database connection — confirm against the callers.
 */
@Select("SELECT LAST_INSERT_ID();")
Integer getId();
}
@@ -0,0 +1,53 @@
package com.common.backend2.response;
/**
 * Catalogue of response codes with their user-facing messages, grouped by module.
 */
public enum ResponseCode {

    // System module
    SUCCESS(0, "操作成功"),
    ERROR(1, "操作失败"),
    SERVER_ERROR(500, "服务器异常"),

    // Common module 1xxxx
    ILLEGAL_ARGUMENT(10000, "参数不合法"),
    REPETITIVE_OPERATION(10001, "请勿重复操作"),
    ACCESS_LIMIT(10002, "请求太频繁, 请稍后再试"),
    MAIL_SEND_SUCCESS(10003, "邮件发送成功"),

    // User module 2xxxx
    NEED_LOGIN(20001, "登录失效"),
    USERNAME_OR_PASSWORD_EMPTY(20002, "用户名或密码不能为空"),
    USERNAME_OR_PASSWORD_WRONG(20003, "用户名或密码错误"),
    USER_NOT_EXISTS(20004, "用户不存在"),
    WRONG_PASSWORD(20005, "密码错误"),

    // File module 3xxxx
    FILE_EMPTY(30001,"文件不能空"),
    FILE_NAME_EMPTY(30002,"文件名称不能为空"),
    FILE_MAX_SIZE(30003,"文件大小超出"),
    ;

    /** Numeric code reported to the client. */
    private Integer code;

    /** Message reported to the client. */
    private String msg;

    ResponseCode(Integer initialCode, String initialMsg) {
        code = initialCode;
        msg = initialMsg;
    }

    /** @return the numeric code of this constant */
    public Integer getCode() {
        return this.code;
    }

    /** @return the message of this constant */
    public String getMsg() {
        return this.msg;
    }

    /**
     * Replaces the code of this constant.
     *
     * @param code the new code
     */
    public void setCode(Integer code) {
        this.code = code;
    }

    /**
     * Replaces the message of this constant.
     *
     * @param msg the new message
     */
    public void setMsg(String msg) {
        this.msg = msg;
    }
}
@@ -0,0 +1,16 @@
package com.common.backend2.response;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Response envelope returned by the controllers: a numeric code, a message and an
 * optional payload. Accessors and constructors are generated by Lombok.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class Result {
// Numeric status code of the response.
private int code;
// Message accompanying the code.
private String message;
// Payload; its concrete type depends on the endpoint.
private Object data;
}
@@ -0,0 +1,52 @@
package com.common.backend2.service;
import com.baomidou.mybatisplus.service.IService;
import com.common.backend2.entity.File;
import com.common.backend2.response.Result;
import org.springframework.web.multipart.MultipartFile;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.InputStream;
/**
 * Service contract for storing, reading, processing, listing and deleting files.
 */
public interface FileService extends IService<File> {
/**
* Uploads a file.
* @param file the uploaded multipart file
* @param request the current HTTP request
* @return the result of the upload
*/
Result upLoadFiles(MultipartFile file, HttpServletRequest request);
/**
* Looks up a file record by id.
* @param id id of the record
* @return the matching file record
*/
File getFileById(String id);
/**
* Opens the content of a file record as a stream.
* @param file the file record
* @return an input stream over the file content
*/
InputStream getFileInputStream(File file);
/**
* Processes a file.
* @param id id of the file record to process
* @return the result of the processing
*/
Result processFile(Integer id);
/**
* Returns result URLs for the given request.
* @param request the current HTTP request
* @param response the current HTTP response
* @return a result carrying the URLs
*/
Result returnUrls(HttpServletRequest request, HttpServletResponse response);
/**
* Deletes the file record with the given id.
* @param id id of the record
* @return the result of the deletion
*/
Result deleteOne(Integer id);
}
@@ -0,0 +1,209 @@
package com.common.backend2.service.Impl;
import com.baomidou.mybatisplus.service.impl.ServiceImpl;
import com.common.backend2.entity.File;
import com.common.backend2.mapper.FileMapper;
import com.common.backend2.response.ResponseCode;
import com.common.backend2.response.Result;
import com.common.backend2.service.FileService;
import com.common.backend2.utils.IpUtil;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.multipart.MultipartFile;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.*;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
@Service
public class FileServiceImpl extends ServiceImpl<FileMapper, File> implements FileService {
@Value("${file.data-path}")
private String dataPath;
// @Value("${file.parent-path}")
// private String parentPath;
@Value("${file.result-path}")
private String resultPath;
@Value("${url.get-result-url}")
private String getResultUrl;
@Autowired
private FileMapper fileMapper;
@Autowired
private RestTemplateMethods restTemplateMethods;
@Transactional
public Result upLoadFiles(MultipartFile file, HttpServletRequest request) {
long MAX_SIZE = 2097152L;
String fileName = file.getOriginalFilename(); //获取原名
if (StringUtils.isEmpty(fileName)) {
return new Result(ResponseCode.FILE_NAME_EMPTY.getCode(), ResponseCode.FILE_NAME_EMPTY.getMsg(), null);
} // 判断不为空
if (file.getSize() > MAX_SIZE) {
return new Result(ResponseCode.FILE_MAX_SIZE.getCode(), ResponseCode.FILE_MAX_SIZE.getMsg(), null);
} // 判断大小是否超出限制
// String suffixName = fileName.contains(".") ? fileName.substring(fileName.lastIndexOf(".")) : null; //获取没有后缀的文件名
// 获取当前时间戳作为文件名
Date date = new Date();
SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd-HH:mm:ss");
String dateString = formatter.format(date);
String saveChildPath = dataPath ;//+ "/" + dateString; // 使用时间戳字符串存储
String resultChildPath = resultPath ;//+ "/" + dateString;
java.io.File newFile = new java.io.File(saveChildPath, fileName); //新建一个文件进行存储
if (!newFile.getParentFile().exists()) {
newFile.getParentFile().mkdirs();
} //文件创建
try {
//文件写入--文件的转存
file.transferTo(newFile);
} catch (IOException e) {
e.printStackTrace();
} // 转存到文件中
File files = new File(null, saveChildPath, resultChildPath, fileName); // 信息包装成对象准备存入数据库
fileMapper.insert(files); //插入到数据库
int fildId = fileMapper.getId();
return new Result(ResponseCode.SUCCESS.getCode(), ResponseCode.SUCCESS.getMsg(), fildId);
}
public File getFileById(String id) {
return fileMapper.selectById(id);
}
public InputStream getFileInputStream(File files) {
java.io.File file = new java.io.File(files.getFilePath());
try {
return new FileInputStream(file);
} catch (FileNotFoundException e) {
e.printStackTrace();
}
return null;
}
public Result processFile(Integer id) {
//
// File file = fileMapper.selectById(id);
// String fileName = file.getFileName();
// String classAndScore=restTemplateMethods.RestTemplatePost(fileName);
// String resultUrl=getResultUrl+"/result/"+fileName;
// ArrayList result =new ArrayList<>();
// result.add(resultUrl);
// result.add(classAndScore);
// //从结果文件夹发回--这个可以用一个get来请求或者使用轮询
// return new Result(ResponseCode.SUCCESS.getCode(), ResponseCode.SUCCESS.getMsg(), result);
File file = fileMapper.selectById(id);
Integer processed = file.getProcessed();
String fileName = file.getFileName();
ArrayList result =new ArrayList<>();
String resultUrl=getResultUrl+"/result/"+fileName;
result.add(resultUrl);
if(processed!=0){
result.add(file.getResult());
}
else{
String classAndScore=restTemplateMethods.RestTemplatePost(fileName);
//将结果字符串保存至数据库
file.setResult(classAndScore);
file.setProcessed(1);
fileMapper.updateById(file);
result.add(classAndScore);
//从结果文件夹发回--这个可以用一个get来请求或者使用轮询
}
return new Result(ResponseCode.SUCCESS.getCode(), ResponseCode.SUCCESS.getMsg(), result);
}
@Override
public Result returnUrls(HttpServletRequest request, HttpServletResponse response) {
String ip = IpUtil.getClientIpAddr(request);
String path = resultPath + "/" + ip; //python 图片目录
String RetPath = getResultUrl + "/result/" + ip + "/";
java.io.File file = new java.io.File(path);
java.io.File[] files = file.listFiles(); //获取所有的文件名
List<String> ResultUrls = new ArrayList<>();
for (java.io.File file1 : files) {
ResultUrls.add(RetPath + file1.getName());
}
return new Result(ResponseCode.SUCCESS.getCode(), ResponseCode.SUCCESS.getMsg(), ResultUrls);
}
@Override
public Result deleteOne(Integer id) {
File file =fileMapper.selectById(id);
// 获取路径
String originPath = file.getFilePath();
String resultPath = file.getSavePath();
String fileName = file.getFileName();
String originFullPath = originPath+"/"+fileName;
String resultFullPath = resultPath+"/"+fileName;
//删除文件
java.io.File originFile = new java.io.File(originFullPath);
java.io.File resultFile = new java.io.File(resultFullPath);
// 检查文件是否存在
if (originFile.exists()) {
// 尝试删除文件
boolean deleted = originFile.delete();
if (deleted) {
System.out.println("Origin文件已成功删除");
} else {
System.out.println("无法删除Origin文件");
return new Result(ResponseCode.ERROR.getCode(), ResponseCode.ERROR.getMsg(), "Origin删除失败");
}
} else {
System.out.println("Origin文件不存在");
return new Result(ResponseCode.ERROR.getCode(), ResponseCode.ERROR.getMsg(), "Origin文件不存在");
}
if (resultFile.exists()) {
// 尝试删除文件
boolean deleted = resultFile.delete();
if (deleted) {
System.out.println("Result文件已成功删除");
} else {
System.out.println("无法删除Result文件");
return new Result(ResponseCode.ERROR.getCode(), ResponseCode.ERROR.getMsg(), "Result删除失败");
}
} else {
System.out.println("Result文件不存在");
}
//删除数据库数据
int result = fileMapper.deleteById(file);
if (result > 0) {
// 删除成功
return new Result(ResponseCode.SUCCESS.getCode(), ResponseCode.SUCCESS.getMsg(), "删除成功");
} else {
// 删除失败
return new Result(ResponseCode.ERROR.getCode(), ResponseCode.ERROR.getMsg(), "数据库删除失败");
}
}
}
@@ -0,0 +1,51 @@
package com.common.backend2.service.Impl;
import jdk.nashorn.internal.parser.JSONParser;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.web.client.RestTemplate;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
/**
 * Client for the prediction service configured under {@code url.predict-url}.
 */
@Service
public class RestTemplateMethods {

    @Autowired
    private RestTemplate restTemplate;

    @Value("${url.predict-url}")
    String predictUrl;

    /**
     * Posts a file name to the prediction service and returns the response body.
     *
     * <p>The request carries a single form field, {@code fileName}. The response is also
     * printed to standard output.
     *
     * @param fileName name of the file the service should process
     * @return the response body as text, as returned by {@code postForObject}
     */
    public String RestTemplatePost(String fileName) {
        LinkedMultiValueMap<String, String> request = new LinkedMultiValueMap<>();
        request.set("fileName", fileName);
        // The HttpURLConnection that used to be configured here was never connected or
        // read; the request has always been sent by RestTemplate alone.
        String result = restTemplate.postForObject(predictUrl, request, String.class);
        System.out.println(result);
        return result;
    }
}
@@ -0,0 +1,71 @@
package com.common.backend2.utils;
import javax.servlet.http.HttpServletRequest;
import java.net.InetAddress;
import java.net.UnknownHostException;
/**
 * Helper for resolving the IP address of the client behind a request.
 */
public class IpUtil {

    /** Headers consulted, in this order, when X-Forwarded-For yields nothing usable. */
    private static final String[] FALLBACK_HEADERS = {
            "Proxy-Client-IP",
            "WL-Proxy-Client-IP",
            "HTTP_X_FORWARDED_FOR",
            "HTTP_X_FORWARDED",
            "HTTP_X_CLUSTER_CLIENT_IP",
            "HTTP_CLIENT_IP",
            "HTTP_FORWARDED_FOR",
            "HTTP_FORWARDED",
            "HTTP_VIA",
            "REMOTE_ADDR",
            "X-Real-IP"
    };

    /**
     * Resolves the client IP of a request.
     *
     * <p>Order of precedence: the first entry of {@code X-Forwarded-For}, then each of the
     * fallback headers, then {@code request.getRemoteAddr()}. When the remote address is
     * the loopback address, the address of the local host is returned instead; if that
     * lookup fails the loopback address is kept.
     *
     * <p>Header values are supplied by the client or by proxies and are returned as
     * received (only the first X-Forwarded-For entry is trimmed).
     *
     * @param request the current request
     * @return the resolved address, possibly {@code null} when the container reports none
     */
    public static String getClientIpAddr(HttpServletRequest request) {
        String ip = request.getHeader("X-Forwarded-For");
        if (!isMissing(ip)) {
            // After several reverse proxies the header lists several addresses;
            // the first one is the original client.
            int index = ip.indexOf(',');
            if (index != -1) {
                ip = ip.substring(0, index);
            }
            ip = ip.trim();
        }

        for (String header : FALLBACK_HEADERS) {
            if (!isMissing(ip)) {
                break;
            }
            ip = request.getHeader(header);
        }

        if (isMissing(ip)) {
            ip = request.getRemoteAddr();
            if ("127.0.0.1".equals(ip) || "0:0:0:0:0:0:0:1".equals(ip)) {
                // Local call: report the address configured on the network interface.
                try {
                    ip = InetAddress.getLocalHost().getHostAddress();
                } catch (UnknownHostException e) {
                    // Keep the loopback address rather than failing the request.
                    e.printStackTrace();
                }
            }
        }
        return ip;
    }

    /** Tells whether a header value carries no address (null, empty or "unknown"). */
    private static boolean isMissing(String value) {
        return value == null || value.length() == 0 || "unknown".equalsIgnoreCase(value);
    }
}
@@ -0,0 +1,41 @@
server:
port: 5000
#port: 80
servlet:
encoding:
charset: utf-8
enabled: true
force: true
spring:
datasource:
driver-class-name: com.mysql.cj.jdbc.Driver
name: defaultDataSource
# url: jdbc:mysql://175.178.186.100:31441/DiatomRecognition?useUnicode=true&characterEncoding=UTF-8&serverTimezone=UTC
url: jdbc:mysql://192.168.161.130:3306/DiatomRecognition?useUnicode=true&characterEncoding=UTF-8&serverTimezone=UTC
username: root
# password: xiyi409@
password: 123456
mybatis-plus:
typeAliasesPackage: com.common.backend2.entity
mapperLocations: classpath:mapper/*.xml
file: # 服务器
data-path: /home/hejinwen/detect/data/upload
result-path: /home/hejinwen/detect/data/save
url: # 服务器
predict-url: http://127.0.0.1:2000/predict
get-result-url: http://192.168.161.130:80 # 走nginx的代理
#url: # 服务器测试
# predict-url: http://175.178.186.100:5000/predict
# get-result-url: wadouri:http://175.178.186.100:6000
#file: # 本地测试
# parent-path: E:\Document\project\temp
# data-path: E:/Document/project/temp/data
# result-path: E:\Document\project\temp/save
#url: # 本地测试
# predict-url: http://127.0.0.1:2000/predict
# get-result-url: http://127.0.0.1:5000
logging:
level:
root: INFO
web: DEBUG
@@ -0,0 +1,13 @@
package com.common.backend2;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
/**
 * Smoke test for the application.
 */
@SpringBootTest
class Backend2ApplicationTests {
// Empty on purpose: the test has no assertions of its own.
@Test
void contextLoads() {
}
}
+5
View File
@@ -0,0 +1,5 @@
// Babel configuration: uses the preset shipped with @vue/cli-plugin-babel.
module.exports = {
presets: [
'@vue/cli-plugin-babel/preset'
]
}
+19
View File
@@ -0,0 +1,19 @@
{
"compilerOptions": {
"target": "es5",
"module": "esnext",
"baseUrl": "./",
"moduleResolution": "node",
"paths": {
"@/*": [
"src/*"
]
},
"lib": [
"esnext",
"dom",
"dom.iterable",
"scripthost"
]
}
}
+15
View File
@@ -0,0 +1,15 @@
// Mock API definitions exported as a list of { method, url, data } entries.
export default [
// Get the current weather
{
method: 'get',
url: '/getWeather',
data: {
code: 200,
message: 'success',
data: {
weather: '小雨转多云',
temperature: '13℃~18℃'
}
}
}
]
+25
View File
@@ -0,0 +1,25 @@
const Mock = require('mockjs')

// Every .js module below this directory, resolved through webpack's require.context.
const files = require.context('.', true, /\.js$/)

// Gather the default export of each mock module. This file itself is skipped before it
// is required, and modules without a default export are ignored.
const mockList = []
files.keys().forEach(key => {
  if (key === './index.js') {
    return
  }
  const definitions = files(key).default
  if (definitions) {
    mockList.push(definitions)
  }
})

// Register every API entry of every module under the /api prefix.
// NOTE(review): an entry's `method` field is not passed to Mock.mock, so the mock is
// registered without an explicit request type — confirm this is intended.
mockList.forEach(list => {
  for (const item of list) {
    Mock.mock('/api' + item.url, item.data)
  }
})
+20270
View File
File diff suppressed because it is too large Load Diff
+51
View File
@@ -0,0 +1,51 @@
{
"name": "bigscreen",
"version": "0.1.0",
"private": true,
"scripts": {
"serve": "vue-cli-service serve",
"build": "vue-cli-service build",
"lint": "vue-cli-service lint"
},
"dependencies": {
"axios": "^1.4.0",
"core-js": "^3.8.3",
"dayjs": "^1.11.8",
"echarts": "^5.4.2",
"element-ui": "^2.15.13",
"mockjs": "^1.1.0",
"vue": "^2.6.14",
"vue-baidu-map": "^0.21.22"
},
"devDependencies": {
"@babel/core": "^7.12.16",
"@babel/eslint-parser": "^7.12.16",
"@vue/cli-plugin-babel": "~5.0.0",
"@vue/cli-plugin-eslint": "~5.0.0",
"@vue/cli-service": "~5.0.0",
"eslint": "^7.32.0",
"eslint-plugin-vue": "^8.0.3",
"less": "^4.1.3",
"less-loader": "^11.1.3",
"vue-template-compiler": "^2.6.14"
},
"eslintConfig": {
"root": true,
"env": {
"node": true
},
"extends": [
"plugin:vue/essential",
"eslint:recommended"
],
"parserOptions": {
"parser": "@babel/eslint-parser"
},
"rules": {}
},
"browserslist": [
"> 1%",
"last 2 versions",
"not dead"
]
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 4.2 KiB

+38
View File
@@ -0,0 +1,38 @@
<!DOCTYPE html>
<html lang="">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width,initial-scale=1.0">
<link rel="icon" href="<%= BASE_URL %>favicon.ico">
<title>
<%= htmlWebpackPlugin.options.title %>
</title>
<style>
* {
margin: 0;
}
html,
body {
height: 100%;
}
</style>
<script>
//fontsize计算 1rem=16px
document.documentElement.style.fontSize = document.documentElement.clientWidth / 1920 * 16 + 'px'
</script>
</head>
<body>
<noscript>
<strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled.
Please enable it to continue.</strong>
</noscript>
<div id="app"></div>
<!-- built files will be auto injected -->
</body>
</html>
+589
View File
@@ -0,0 +1,589 @@
<template>
<div id="app">
<div class="container">
<div class="content-top">
<span class="title">硅藻识别系统</span>
<div class="time-box">
<div id="time">{{ nowTime }}</div>
<div class="line-box line-box1"></div>
</div>
<div class="weather-box">
<div class="line-box line-box2"></div>
<div id="weather">
{{ weatherObj.weather }}&nbsp;&nbsp;&nbsp;{{
weatherObj.temperature
}}
</div>
</div>
</div>
<div class="content-bottom">
<div class="content-left">
<section class="section-box roadDistribute-box">
<!-- <div class="content-bodyhead">
<div style="color: rgb(179, 179, 200); font-size: 18px;text-align: center;padding: 10px;">诊断信息输入</div>
死者姓名<input type="text" placeholder="在这里输入文本">
<br>性别<input type="text" placeholder="在这里输入文本">
<br>法医鉴定<input type="text" placeholder="在这里输入文本">
<br>送检材料<input type="text" placeholder="在这里输入文本">
<br><label>
<input type="checkbox" name="option1" value="Option 1"> 小环藻属
</label>
<label>
<input type="checkbox" name="option2" value="Option 2"> 直链藻属
</label>
<label>
<input type="checkbox" name="option3" value="Option 3"> 菱形藻属
</label>
</div> -->
<div class="title center" >源文件上传</div>
<div class="content-body">
<div class="distribute-textBox">
<div style="display: flex;justify-content: center; align-items: center; ">
<!-- <img src="./assets/image/gdut.jpg" alt="" > -->
<div class="imagebox" >
<el-upload
class="avatar-uploader"
action= http://192.168.161.130:80/api/upload
:show-file-list="false"
:on-success="handleAvatarSuccess">
<img v-if="imageUrl" :src="imageUrl" class="avatar">
<i v-else class="el-icon-plus avatar-uploader-icon"></i>
<el-button size="small" class="btn" type="btn-primary">点击上传</el-button>
</el-upload>
</div>
</div>
<div class="table">
<div class="title center" >检测结果</div>
<thead>
<tr>
<td align="center" style="width:200px;">类别</td>
<td align="center" style="width:200px;">概率</td>
</tr>
</thead>
<tbody>
  <!-- Keys must be primitives; binding the row object itself makes Vue 2
       warn about a non-primitive key, so the row index is used instead. -->
  <tr v-for="(item, index) in tableData" :key="index">
    <td align="center">{{item.class}}</td>
    <td align="center">{{item.score}}</td>
  </tr>
</tbody>
</div>
</div>
</div>
</section>
</div>
<div class="content-middle">
<div class="image-box">
<img :src="resultUrl" id="100"
alt="">
</div>
<div class="btnArea">
<button type="btn-primary"
@click="deleteImg"
class="btn">删除</button>
<button type="btn-primary"
@click="getResult"
class="btn">获取结果</button>
</div>
</div>
<div class="content-right">
<!-- <div class="section-box eventStatistics-box">
<img src="./assets/image/logo.png"
alt="">
</div> -->
<div class="title">我国主要水藻图例</div>
<div class="content">
<div class="row">
<div class="column">
<img src="./assets/image/a.png" alt="">
<br>小环藻属
</div>
<div class="column">
<img src="./assets/image/b.png" alt="">
<br>直链藻属
</div>
</div>
<div class="row">
<div class="column">
<img src="./assets/image/c.png" alt="">
<br>菱形藻属
</div>
<div class="column">
<img src="./assets/image/d.png" alt="">
<br>卵形藻属
</div>
</div>
<div class="row">
<div class="column">
<img src="./assets/image/e.png" alt="">
<br>桥弯藻属
</div>
<div class="column">
<img src="./assets/image/f.png" alt="">
<br>异形藻属
</div>
</div>
<div class="row">
<div class="column">
<img src="./assets/image/g.png" alt="">
<br>针杆藻属
</div>
<div class="column">
<img src="./assets/image/h.png" alt="">
<br>舟形藻属
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</template>
<script>
import dayjs from "dayjs";
import {
getWeather
} from "@/api/bigScreen-api";
import nodata from"./assets/image/no-data.png";
import axios from 'axios'
// Root component: shows a clock and weather header, lets the user upload an
// image, request a prediction for it and delete it again.
export default {
  name: "App",
  components: {},
  data () {
    return {
      // Handle of the interval that refreshes `nowTime`.
      timer: null,
      // baseUrL:'http://192.168.161.130/',
      // Base URL of the backend; the API paths below are appended to it.
      baseUrL: 'http://192.168.161.130:80/',
      nowTime: dayjs(new Date()).format("YYYY-MM-DD HH:mm"),
      weatherObj: {},
      resultClass: "",
      resultScore: "",
      // Preview shown in the upload box.
      imageUrl: 'http://192.168.161.130:80/assets/gdut.jpg',
      // Id used in the predict/delete URLs; replaced after each upload.
      processIdNow: 1,
      // Image shown in the middle panel (placeholder until a result arrives).
      resultUrl: nodata,
      // Result rows of the form { class, score }.
      tableData: []
    };
  },
  mounted () {
    this.init();
  },
  // NOTE: these two hooks used to be nested inside `methods`, where Vue never
  // invokes them as lifecycle hooks, so the interval was never cleared.
  deactivated () {
    clearInterval(this.timer);
  },
  beforeDestroy () {
    clearInterval(this.timer);
  },
  methods: {
    // Start the clock refresh and load the weather once.
    init () {
      this.getNowTime(); // start refreshing the current time
      this.getWeather(); // load the weather
    },
    // Refresh `nowTime` every 30 seconds; any previous interval is cleared
    // first so repeated calls never stack timers.
    getNowTime () {
      clearInterval(this.timer);
      this.timer = setInterval(() => {
        this.nowTime = dayjs(new Date()).format("YYYY-MM-DD HH:mm");
      }, 1000 * 30);
    },
    // Fetch the weather and store it when the response code is 200.
    getWeather () {
      getWeather().then((res) => {
        if (res.code == 200) {
          this.weatherObj = res.data;
        }
      });
    },
    // Upload success callback: preview the local file and remember the value
    // returned in `res.data`, which is used as the id in later API calls.
    handleAvatarSuccess (res, file) {
      this.imageUrl = URL.createObjectURL(file.raw);
      this.processIdNow = res.data;
    },
    // Ask for confirmation, then delete the current upload on the server.
    deleteImg () {
      this.$confirm('此操作将永久删除该文件, 是否继续?', '提示', {
        confirmButtonText: '确定',
        cancelButtonText: '取消',
        type: 'warning'
      }).then(() => {
        axios.delete(this.baseUrL + `api/delete/${this.processIdNow}`)
          .then((res) => {
            if (res.data.code == 0) {
              this.$message.success('删除成功')
            } else {
              this.$message.error('删除失败')
            }
          })
          .catch(() => {
            this.$message.error('删除失败')
          });
      }).catch(() => {
        // The confirm dialog was dismissed.
        this.$message({
          type: 'info',
          message: '已取消删除'
        });
      });
    },
    // Request the prediction for the current upload. On success the payload
    // is read as [image URL, JSON string of { class, score } rows].
    getResult () {
      axios.get(this.baseUrL + `api/predict/${this.processIdNow}`)
        .then((res) => {
          if (res.data.code != 0) {
            this.$message.error('失败')
            return
          }
          const payload = res.data.data;
          // Update the bound property instead of writing to the DOM node
          // directly, so the displayed image stays in sync with the state.
          this.resultUrl = payload[0];
          const rows = JSON.parse(payload[1]);
          this.tableData = rows.map((row) => ({
            class: row.class,
            score: row.score
          }));
        })
        .catch(() => {
          this.$message.error('文件已删除或未上传文件')
        });
    }
  }
};
</script>
<style lang="less" scoped>
#app {
height: 100%;
}
/* 布局相关 start */
.container {
height: 100%;
background-image: url("~@/assets/image/background.png");
background-size: cover;
}
.content-title {
height: 2rem;
display: flex;
align-items: center;
font-size: 1.2rem;
font-weight: bolder;
color: #0166e2;
img {
margin-right: 5px;
}
}
.content-body {
height: calc(100% - 12.2rem);
margin-top: 1rem;
display: flex;
flex-direction: column;
justify-content: space-between;
}
.content-bodyhead{
text-align:left;
border:solid rgb(118, 103, 118);
padding: 10px;
color: #8c939d;
}
.content-bodyhead text{
padding: 4px 0;
}
.section-box {
box-sizing: border-box;
padding: 1rem;
border: 2px solid #ffffff;
}
.content-bottom {
display: flex;
flex-wrap: nowrap;
position: relative;
height: calc(100% - 7rem);
margin-top: -1rem;
padding: 0 1rem 1rem 1rem;
}
.content-bottom .title {
font-size: 24px;
font-weight: bold;
background-color: #2e325a;
color: #c4c6c9;
bottom: 10px;
}
.center{
display: flex;
align-content: center;
justify-content: center;
}
.content-left {
width: 27%;
margin-right: 1rem;
}
.content-middle {
width: 46%;
margin-right: 1rem;
display: flex;
flex-direction: column;
justify-content: space-between;
}
.content-right {
width: 27%;
text-align: center;
// border-style:solid;
border-color:#cfc5c5;
background: rgba(255, 255, 255, 0.2); /* 使用半透明白色作为背景颜色 */
backdrop-filter: blur(10px); /* 使用模糊滤镜来创建磨砂玻璃效果 */
}
//右边盒子的布局
.content-right .title {
font-size: 24px;
font-weight: bold;
background-color: #2e325a;
color: #c4c6c9;
bottom: 10px;
}
.content {
display: flex;
flex-wrap: wrap;
justify-content: center;
}
.row {
display: flex;
flex-wrap: nowrap;
margin: 10px 0 0 0;
}
.column {
flex: 1;
border: 2px solid #ccc;
padding: 10px;
margin: 0 5px;
overflow: hidden;
color:#c4c6c9;
background-color: #16388e;
// max-height: 100%; /* 限制行的最大高度 */
}
.column img {
max-width: 66%; /* 图片的最大宽度为列的宽度 */
height: auto; /* 保持图片的纵横比 */
}
/* 布局相关 end */
/* header start */
.content-top {
position: relative;
height: 7rem;
text-align: center;
font-style: italic;
font-weight: bolder;
background: url("@/assets/image/background-top.jpg") no-repeat center;
background-size: 100% 100%;
background-position-y: -0.6875rem;
color: rgb(194, 193, 194);
.title {
line-height: 5rem;
font-size: 2rem;
letter-spacing: 2px;
}
.time-box {
position: absolute;
left: 5rem;
top: 0.5rem;
display: flex;
height: 30%;
font-size: 1.0625rem;
}
#time,
#weather {
padding: 0.5rem 1.2rem 0 1.2rem;
border-top: 2px solid #0166e2;
letter-spacing: 1px;
}
.weather-box {
position: absolute;
right: 5rem;
top: 0.5rem;
display: flex;
height: 30%;
font-size: 1.2rem;
}
.line-box {
width: 3rem;
height: 30%;
border-top: 2px solid #0166e2;
}
.line-box1 {
margin-left: 1rem;
}
.line-box2 {
margin-right: 1rem;
}
}
/* header end */
.roadDistribute-box {
height: 100%;
.distribute-textBox {
height: 40%;
padding: 1rem;
line-height: 3.2rem;
letter-spacing: 1px;
font-size: 1.5rem;
color: #ffffff;
}
}
.imagebox {
display: flex;
justify-content: center; //弹性盒子对象在主轴上的对齐方式
align-items: center; //定义flex子项在flex容器的当前行的侧轴(纵轴)方向上的对齐方式。
box-sizing: border-box;
border-color: #ffffff;
background: url("@/assets/image/background-map.png") no-repeat center;
position:relative
}
.imagebox img {
width: 300px;
height: 300px;
border: 1px solid #fff;
background-color: #fff;
}
.image-box {
  display: flex;
  justify-content: center; // alignment of the flex items along the main axis
  align-items: center; // alignment of the flex items along the cross axis
  box-sizing: border-box;
  background: url("@/assets/image/background-map.png") no-repeat center;
  // Must follow the `background` shorthand: the shorthand resets
  // background-size to `auto`, which silently discarded this declaration
  // when it was written first.
  background-size: 100% 100%;
}
.image-box img {
width: 87%;
height: 87%;
background-color: #fff;
}
.eventStatistics-box {
height: 100%;
display: flex;
flex-direction: column;
justify-content: space-between;
}
.eventStatistics-box img {
width: 100%;
height: 50%;
background-color: #fff;
}
.btnArea {
display: flex;
flex-direction: row;
justify-content: flex-start;
align-items: center;
height: 10%;
}
.btn {
background-image: linear-gradient(to bottom right, #196cd8, #3c85c0, #63a6e4);
overflow: hidden;
font-size: 20px;
width: 60%;
height: 80%;
color: #ffffff;
border: 0 solid #fff;
opacity: 0.8;
margin: 0 2px;
border-radius: 20px;
}
.avatar-uploader .el-upload {
border: 1px dashed #d9d9d9;
border-radius: 6px;
cursor: pointer;
position: relative;
overflow: hidden;
}
.avatar-uploader .el-upload:hover {
border-color: #409EFF;
}
.avatar-uploader-icon {
font-size: 28px;
color: #8c939d;
width: 178px;
height: 178px;
line-height: 178px;
text-align: center;
}
.avatar {
width: 178px;
height: 178px;
display: block;
}
.above_txt{
position:absolute;top:108px;left:108px;
}
.table{
margin-top :2rem;
line-height: 2rem;
letter-spacing: 0px;
font-size:1.3rem;
}
.table .title{
margin-top :2rem;
letter-spacing: 0px;
font-size:1.5rem;
margin-bottom: 1rem;
}
</style>
+10
View File
@@ -0,0 +1,10 @@
import request from '@/utils/request'
// Fetch the weather: issues GET /getWeather with `param` as the query
// parameters and returns whatever `request` returns.
export function getWeather (param) {
  const options = {
    url: '/getWeather',
    method: 'get',
    params: param
  }
  return request(options)
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 126 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 128 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 35 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 174 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 134 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 149 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 145 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 159 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 146 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 436 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 118 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

+14
View File
@@ -0,0 +1,14 @@
// Application entry point: registers Element UI, loads the mock layer and
// mounts the root component on #app.
import Vue from 'vue'
import App from './App.vue'
import ElementUI from 'element-ui'
import 'element-ui/lib/theme-chalk/index.css'
import '../mock/index'

Vue.config.productionTip = false
Vue.use(ElementUI)

const rootInstance = new Vue({
  render: createElement => createElement(App),
})
rootInstance.$mount('#app')
+89
View File
@@ -0,0 +1,89 @@
import axios from 'axios'
import { Notification } from 'element-ui'
// Shared axios instance used by the API modules.
const serviceOptions = {
  baseURL: '/api',
  // Request timeout in milliseconds.
  timeout: 80000,
  withCredentials: true,
  // crossDomain: true
}
const service = axios.create(serviceOptions)
// Request interceptor: sets the Accept-Language and Content-Type headers on
// every outgoing request.
service.interceptors.request.use(
  config => {
    // if (getToken()) {
    //   config.headers['Authorization'] = getToken() // attach a custom token to every request; adapt as needed
    // }
    // Language preference stored by the i18n setup; defaults to zh_CN.
    const lang = localStorage.getItem('lang') || 'zh_CN'
    config.headers['Accept-Language'] = lang.replace(/_/g, '-')
    config.headers['Content-Type'] = 'application/json'
    return config
  },
  error => {
    // The rejection must be returned; without `return` the handler resolves
    // with undefined and the error is silently swallowed.
    return Promise.reject(error)
  }
)
// Response interceptor: unwraps `response.data` on success and shows a
// notification on failure. The original error is always re-rejected so that
// callers can still handle it.
service.interceptors.response.use(
  response => {
    return response.data
  },
  error => {
    const response = error.response

    // No response at all (timeout, network failure, aborted request). The
    // previous code dereferenced `error.response.data` unconditionally, which
    // threw a TypeError here and made the timeout notification unreachable.
    if (!response) {
      const timedOut = error.code === 'ECONNABORTED' ||
        error.toString().indexOf('Error: timeout') !== -1
      Notification.error({
        title: timedOut ? '网络请求超时' : '接口请求失败',
        duration: 5000
      })
      return Promise.reject(error)
    }

    const payload = response.data

    // A blob download that failed with a JSON error body: read the blob as
    // text and show the message it carries.
    if (payload instanceof Blob && payload.type.toLowerCase().indexOf('json') !== -1) {
      const reader = new FileReader()
      reader.onload = function () {
        let errorMsg
        try {
          errorMsg = JSON.parse(reader.result).message
        } catch (e) {
          // Body was not valid JSON after all; fall back to the generic text.
          errorMsg = '接口请求失败'
        }
        Notification.error({
          title: errorMsg,
          duration: 5000
        })
      }
      reader.readAsText(payload, 'utf-8')
      return Promise.reject(error)
    }

    // `status` is read from the response body, not from the HTTP status line.
    const code = payload ? payload.status : 0
    if (!code) {
      Notification.error({
        title: '接口请求失败',
        duration: 5000
      })
    }
    // Status-specific handling is currently disabled:
    // if (code === 401) {
    //   store.dispatch('LogOut').then(() => {
    //     // hint for the login page
    //     Cookies.set('point', 401)
    //     location.reload()
    //   })
    // } else if (code === 403) {
    //   router.push({ path: '/401' })
    // } else {
    //   const errorMsg = error.response.data.message
    //   if (errorMsg !== undefined) {
    //     Notification.error({
    //       title: errorMsg,
    //       duration: 0
    //     })
    //   }
    // }
    return Promise.reject(error)
  }
)
export default service
+16
View File
@@ -0,0 +1,16 @@
const { defineConfig } = require('@vue/cli-service')
module.exports = defineConfig({
transpileDependencies: true,
devServer: {
open: true, // 运行后自动打开浏览器
port: 8080,
// proxy: {
// '/api': {
// target: 'http://192.168.1.102:8088',
// ws: false,
// changeOrigin: true,
// logLevel: 'debug'
// }
// }
},
})
+140
View File
@@ -0,0 +1,140 @@
# ignore map, miou, datasets
map_out/
miou_out/
VOCdevkit/
datasets/
Medical_Datasets/
lfw/
logs/
model_data/
.temp_map_out/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
File diff suppressed because it is too large Load Diff
+209
View File
@@ -0,0 +1,209 @@
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001406.jpg 243,20,271,419,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001872.jpg 72,116,437,352,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000618.jpg 70,168,421,254,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000583.jpg 200,96,310,396,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000784.jpg 141,104,380,360,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000210.jpg 119,106,383,370,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000915.jpg 222,50,294,422,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000004.jpg 174,142,327,295,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001580.jpg 130,103,408,361,4
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002486.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002287.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000883.jpg 111,150,402,285,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001285.jpg 196,70,326,416,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000640.jpg 125,27,378,412,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001855.jpg 201,109,294,343,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001989.jpg 128,165,406,293,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001300.jpg 154,6,326,410,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000798.jpg 90,99,410,281,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001990.jpg 157,90,329,305,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002152.jpg 186,123,328,339,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002407.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000291.jpg 189,160,317,290,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002381.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000657.jpg 211,143,290,311,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002015.jpg 127,116,419,244,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001517.jpg 110,114,411,298,4
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000383.jpg 136,126,342,333,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000154.jpg 141,112,367,334,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000480.jpg 194,68,316,406,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000846.jpg 230,58,306,391,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001419.jpg 207,52,305,389,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002539.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001318.jpg 70,23,382,437,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001805.jpg 74,154,416,292,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001021.jpg 53,145,443,228,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000484.jpg 109,149,414,240,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002356.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001985.jpg 194,105,359,329,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002352.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001148.jpg 89,89,413,354,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001066.jpg 152,133,381,319,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001489.jpg 206,94,307,353,4
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001852.jpg 60,144,427,257,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000762.jpg 72,195,449,273,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000709.jpg 154,111,358,322,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001017.jpg 209,51,286,422,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001785.jpg 110,185,404,310,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002260.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001718.jpg 1,122,475,248,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002128.jpg 130,117,392,282,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000898.jpg 47,97,483,371,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000503.jpg 115,54,350,382,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000961.jpg 33,162,481,253,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001691.jpg 153,57,352,348,4
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000337.jpg 183,159,311,284,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002061.jpg 116,109,396,312,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002143.jpg 117,63,398,320,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000735.jpg 155,56,334,371,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001213.jpg 150,45,354,392,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001251.jpg 229,80,284,386,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002163.jpg 110,137,417,330,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002448.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000086.jpg 138,123,352,332,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002581.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001765.jpg 186,148,332,302,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001759.jpg 109,22,342,374,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001772.jpg 166,15,340,388,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001910.jpg 188,128,286,382,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002132.jpg 100,98,453,371,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001648.jpg 99,162,361,257,4
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001762.jpg 223,114,318,340,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002157.jpg 204,124,314,332,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001701.jpg 136,71,360,420,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001666.jpg 135,91,370,318,4
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001850.jpg 53,158,452,307,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000122.jpg 191,167,305,278,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001741.jpg 172,65,355,414,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001519.jpg 60,131,433,324,4
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001252.jpg 71,32,452,415,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002371.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000936.jpg 129,59,384,410,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002173.jpg 130,135,403,301,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001579.jpg 25,54,404,421,4
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002307.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001681.jpg 115,108,379,277,4
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000547.jpg 135,169,343,272,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002215.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000861.jpg 104,62,370,424,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002398.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000893.jpg 217,41,295,420,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002064.jpg 143,113,365,351,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002522.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000856.jpg 107,155,383,234,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000228.jpg 176,136,356,315,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000633.jpg 109,117,416,316,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000022.jpg 155,105,341,290,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001088.jpg 80,145,431,305,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002565.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002276.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000308.jpg 90,68,385,359,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000828.jpg 218,51,309,354,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001180.jpg 87,138,374,342,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000264.jpg 118,128,317,329,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000782.jpg 191,5,318,396,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000157.jpg 103,79,402,380,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001228.jpg 181,62,333,388,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002034.jpg 144,86,329,412,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001423.jpg 147,89,382,378,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000060.jpg 117,118,393,393,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001720.jpg 51,159,434,298,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001024.jpg 26,19,460,422,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001766.jpg 133,31,370,399,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000359.jpg 109,66,369,329,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000943.jpg 114,193,395,248,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000216.jpg 141,106,393,363,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001768.jpg 172,1,330,396,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002488.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002186.jpg 154,87,351,397,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000469.jpg 67,152,426,307,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002406.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000506.jpg 120,160,417,297,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002077.jpg 114,107,380,324,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000093.jpg 137,129,358,347,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002239.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001714.jpg 187,51,314,393,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000849.jpg 119,112,394,309,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001811.jpg 190,79,381,364,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001173.jpg 110,99,401,364,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000047.jpg 159,115,365,318,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001722.jpg 150,116,378,300,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001605.jpg 160,35,352,355,4
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002266.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000971.jpg 73,180,419,218,2 76,204,418,246,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002177.jpg 120,63,391,374,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001166.jpg 191,37,329,444,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002304.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000962.jpg 107,29,397,381,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001005.jpg 43,122,466,347,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001556.jpg 180,127,328,304,4
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001776.jpg 171,147,316,347,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002370.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001903.jpg 206,60,284,402,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000651.jpg 129,137,437,323,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001212.jpg 70,77,445,339,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002274.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002102.jpg 88,97,407,333,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000705.jpg 135,130,382,325,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002455.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000776.jpg 76,85,413,310,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002261.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001132.jpg 32,98,440,280,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000582.jpg 180,103,293,351,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001026.jpg 125,177,387,237,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002011.jpg 126,78,391,437,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001115.jpg 212,44,293,403,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001839.jpg 70,167,438,282,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001122.jpg 95,30,405,442,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002063.jpg 143,88,357,374,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001283.jpg 43,117,420,359,3 212,132,334,443,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001725.jpg 210,139,287,349,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000973.jpg 152,7,373,426,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000237.jpg 115,108,380,374,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000942.jpg 55,178,443,274,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001745.jpg 87,118,425,310,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002360.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001790.jpg 144,112,405,366,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002312.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000857.jpg 137,205,358,264,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001594.jpg 211,66,300,399,4
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000663.jpg 176,123,346,336,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001964.jpg 109,165,392,302,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001395.jpg 76,66,408,378,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000843.jpg 71,111,418,400,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000371.jpg 133,119,370,352,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000795.jpg 225,102,311,386,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001211.jpg 107,21,420,407,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000835.jpg 118,108,343,320,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001366.jpg 85,81,433,376,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002242.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000677.jpg 86,112,421,341,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002289.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000381.jpg 135,117,360,334,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000722.jpg 154,180,382,267,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002336.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000350.jpg 141,72,415,343,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002311.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002354.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000485.jpg 221,137,291,331,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000130.jpg 118,73,369,321,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001465.jpg 16,90,459,390,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000964.jpg 208,30,318,402,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000457.jpg 220,75,300,328,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002018.jpg 138,47,360,438,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000768.jpg 104,88,400,381,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000479.jpg 187,82,340,396,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000811.jpg 8,133,479,332,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000834.jpg 31,39,489,388,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000007.jpg 138,112,373,343,0
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002146.jpg 189,123,323,341,6
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001070.jpg 113,123,406,333,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002579.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000787.jpg 166,45,325,429,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000594.jpg 122,63,393,434,1
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002562.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001371.jpg 158,93,323,397,3
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000920.jpg 68,90,416,366,2
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002458.jpg
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001826.jpg 94,149,386,384,5 26,37,76,152,5
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002434.jpg
+201
View File
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright Megvii, Base Detection
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+180
View File
@@ -0,0 +1,180 @@
## YOLOX:You Only Look Once目标检测模型在Pytorch当中的实现
---
## 目录
1. [仓库更新 Top News](#top-news)
2. [相关仓库 Related code](#相关仓库)
3. [性能情况 Performance](#性能情况)
4. [实现的内容 Achievement](#实现的内容)
5. [所需环境 Environment](#所需环境)
6. [文件下载 Download](#文件下载)
7. [训练步骤 How2train](#训练步骤)
8. [预测步骤 How2predict](#预测步骤)
9. [评估步骤 How2eval](#评估步骤)
10. [参考资料 Reference](#Reference)
## Top News
**`2022-04`**:**支持多GPU训练,新增各个种类目标数量计算,新增heatmap。**
**`2022-03`**:**进行了大幅度的更新,支持step、cos学习率下降法、支持adam、sgd优化器选择、支持学习率根据batch_size自适应调整、新增图片裁剪。**
BiliBili视频中的原仓库地址为:https://github.com/bubbliiiing/yolox-pytorch/tree/bilibili
**`2021-10`**:**创建仓库,支持不同尺寸模型训练、支持大量可调整参数,支持fps、视频预测、批量预测等功能。**  
## 相关仓库
| 模型 | 路径 |
| :----- | :----- |
YoloV3 | https://github.com/bubbliiiing/yolo3-pytorch
Efficientnet-Yolo3 | https://github.com/bubbliiiing/efficientnet-yolo3-pytorch
YoloV4 | https://github.com/bubbliiiing/yolov4-pytorch
YoloV4-tiny | https://github.com/bubbliiiing/yolov4-tiny-pytorch
Mobilenet-Yolov4 | https://github.com/bubbliiiing/mobilenet-yolov4-pytorch
YoloV5-V5.0 | https://github.com/bubbliiiing/yolov5-pytorch
YoloV5-V6.1 | https://github.com/bubbliiiing/yolov5-v6.1-pytorch
YoloX | https://github.com/bubbliiiing/yolox-pytorch
YoloV7 | https://github.com/bubbliiiing/yolov7-pytorch
YoloV7-tiny | https://github.com/bubbliiiing/yolov7-tiny-pytorch
## 性能情况
| 训练数据集 | 权值文件名称 | 测试数据集 | 输入图片大小 | mAP 0.5:0.95 | mAP 0.5 |
| :-----: | :-----: | :------: | :------: | :------: | :-----: |
| COCO-Train2017 | [yolox_nano.pth](https://github.com/bubbliiiing/yolox-pytorch/releases/download/v1.0/yolox_nano.pth) | COCO-Val2017 | 640x640 | 27.4 | 44.5
| COCO-Train2017 | [yolox_tiny.pth](https://github.com/bubbliiiing/yolox-pytorch/releases/download/v1.0/yolox_tiny.pth) | COCO-Val2017 | 640x640 | 34.7 | 53.6
| COCO-Train2017 | [yolox_s.pth](https://github.com/bubbliiiing/yolox-pytorch/releases/download/v1.0/yolox_s.pth) | COCO-Val2017 | 640x640 | 38.2 | 57.7
| COCO-Train2017 | [yolox_m.pth](https://github.com/bubbliiiing/yolox-pytorch/releases/download/v1.0/yolox_m.pth) | COCO-Val2017 | 640x640 | 44.8 | 63.9
| COCO-Train2017 | [yolox_l.pth](https://github.com/bubbliiiing/yolox-pytorch/releases/download/v1.0/yolox_l.pth) | COCO-Val2017 | 640x640 | 47.9 | 66.6
| COCO-Train2017 | [yolox_x.pth](https://github.com/bubbliiiing/yolox-pytorch/releases/download/v1.0/yolox_x.pth) | COCO-Val2017 | 640x640 | 49.0 | 67.7
## 实现的内容
- [x] 主干特征提取网络:使用了Focus网络结构。
- [x] 分类回归层:Decoupled Head,在YoloX中,Yolo Head被分为了分类回归两部分,最后预测的时候才整合在一起。
- [x] 训练用到的小技巧:Mosaic数据增强、IOU和GIOU、学习率余弦退火衰减。
- [x] Anchor Free:不使用先验框
- [x] SimOTA:为不同大小的目标动态匹配正样本。
## 所需环境
pytorch==1.2.0
## 文件下载
训练所需的权值可在百度网盘中下载。
链接: https://pan.baidu.com/s/1bi2UBwwIHES0OpLeyYuBFg
提取码: f4ni
VOC数据集下载地址如下,里面已经包括了训练集、测试集、验证集(与测试集一样),无需再次划分:
链接: https://pan.baidu.com/s/19Mw2u_df_nBzsC2lg20fQA
提取码: j5ge
## 训练步骤
### a、训练VOC07+12数据集
1. 数据集的准备
**本文使用VOC格式进行训练,训练前需要下载好VOC07+12的数据集,解压后放在根目录**
2. 数据集的处理
修改voc_annotation.py里面的annotation_mode=2,运行voc_annotation.py生成根目录下的2007_train.txt和2007_val.txt。
3. 开始网络训练
train.py的默认参数用于训练VOC数据集,直接运行train.py即可开始训练。
4. 训练结果预测
训练结果预测需要用到两个文件,分别是yolo.py和predict.py。我们首先需要去yolo.py里面修改model_path以及classes_path,这两个参数必须要修改。
**model_path指向训练好的权值文件,在logs文件夹里。
classes_path指向检测类别所对应的txt。**
完成修改后就可以运行predict.py进行检测了。运行后输入图片路径即可检测。
### b、训练自己的数据集
1. 数据集的准备
**本文使用VOC格式进行训练,训练前需要自己制作好数据集,**
训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的Annotation中。
训练前将图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中。
2. 数据集的处理
在完成数据集的摆放之后,我们需要利用voc_annotation.py获得训练用的2007_train.txt和2007_val.txt。
修改voc_annotation.py里面的参数。第一次训练可以仅修改classes_path,classes_path用于指向检测类别所对应的txt。
训练自己的数据集时,可以自己建立一个cls_classes.txt,里面写自己所需要区分的类别。
model_data/cls_classes.txt文件内容为:
```python
cat
dog
...
```
修改voc_annotation.py中的classes_path,使其对应cls_classes.txt,并运行voc_annotation.py。
3. 开始网络训练
**训练的参数较多,均在train.py中,大家可以在下载库后仔细看注释,其中最重要的部分依然是train.py里的classes_path。**
**classes_path用于指向检测类别所对应的txt,这个txt和voc_annotation.py里面的txt一样!训练自己的数据集必须要修改!**
修改完classes_path后就可以运行train.py开始训练了,在训练多个epoch后,权值会生成在logs文件夹中。
4. 训练结果预测
训练结果预测需要用到两个文件,分别是yolo.py和predict.py。在yolo.py里面修改model_path以及classes_path。
**model_path指向训练好的权值文件,在logs文件夹里。
classes_path指向检测类别所对应的txt。**
完成修改后就可以运行predict.py进行检测了。运行后输入图片路径即可检测。
## 预测步骤
### a、使用预训练权重
1. 下载完库后解压,在百度网盘下载yolo_weights.pth,放入model_data,运行predict.py,输入
```python
img/street.jpg
```
2. 在predict.py里面进行设置可以进行fps测试和video视频检测。
### b、使用自己训练的权重
1. 按照训练步骤训练。
2. 在yolo.py文件里面,在如下部分修改model_path和classes_path使其对应训练好的文件;**model_path对应logs文件夹下面的权值文件,classes_path是model_path对应分的类**。
```python
_defaults = {
#--------------------------------------------------------------------------#
# 使用自己训练好的模型进行预测一定要修改model_path和classes_path
# model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
# 如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改
#--------------------------------------------------------------------------#
"model_path" : 'model_data/yolox_s.pth',
"classes_path" : 'model_data/coco_classes.txt',
#---------------------------------------------------------------------#
# 输入图片的大小,必须为32的倍数。
#---------------------------------------------------------------------#
"input_shape" : [640, 640],
#---------------------------------------------------------------------#
# 所使用的YoloX的版本。nano、tiny、s、m、l、x
#---------------------------------------------------------------------#
"phi" : 's',
#---------------------------------------------------------------------#
# 只有得分大于置信度的预测框会被保留下来
#---------------------------------------------------------------------#
"confidence" : 0.5,
#---------------------------------------------------------------------#
# 非极大抑制所用到的nms_iou大小
#---------------------------------------------------------------------#
"nms_iou" : 0.3,
#---------------------------------------------------------------------#
# 该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize,
# 在多次测试后,发现关闭letterbox_image直接resize的效果更好
#---------------------------------------------------------------------#
"letterbox_image" : True,
#-------------------------------#
# 是否使用Cuda
# 没有GPU可以设置成False
#-------------------------------#
"cuda" : True,
}
```
3. 运行predict.py,输入
```python
img/street.jpg
```
4. 在predict.py里面进行设置可以进行fps测试和video视频检测。
## 评估步骤
### a、评估VOC07+12的测试集
1. 本文使用VOC格式进行评估。VOC07+12已经划分好了测试集,无需利用voc_annotation.py生成ImageSets文件夹下的txt。
2. 在yolo.py里面修改model_path以及classes_path。**model_path指向训练好的权值文件,在logs文件夹里。classes_path指向检测类别所对应的txt。**
3. 运行get_map.py即可获得评估结果,评估结果会保存在map_out文件夹中。
### b、评估自己的数据集
1. 本文使用VOC格式进行评估。
2. 如果在训练前已经运行过voc_annotation.py文件,代码会自动将数据集划分成训练集、验证集和测试集。如果想要修改测试集的比例,可以修改voc_annotation.py文件下的trainval_percent。trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1。train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1。
3. 利用voc_annotation.py划分测试集后,前往get_map.py文件修改classes_path,classes_path用于指向检测类别所对应的txt,这个txt和训练时的txt一样。评估自己的数据集必须要修改。
4. 在yolo.py里面修改model_path以及classes_path。**model_path指向训练好的权值文件,在logs文件夹里。classes_path指向检测类别所对应的txt。**
5. 运行get_map.py即可获得评估结果,评估结果会保存在map_out文件夹中。
## Reference
https://github.com/Megvii-BaseDetection/YOLOX
+17
View File
@@ -0,0 +1,17 @@
<!--
* @Author: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
* @Date: 2023-03-31 20:04:39
* @LastEditors: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
* @LastEditTime: 2023-03-31 20:04:47
* @FilePath: /luyuetong-data/yolox-pytorch/display_image.html
* @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
-->
<!DOCTYPE html>
<html>
<head>
<title>Display Image</title>
</head>
<body>
<!-- Shows the uploaded image. The image URL is resolved by Flask's url_for from an endpoint named 'upload_image'. -->
<!-- NOTE(review): no 'upload_image' endpoint is visible in the accompanying Flask code; confirm it exists, otherwise rendering raises a BuildError. -->
<img src="{{ url_for('upload_image') }}" alt="Uploaded Image">
</body>
</html>
+95
View File
@@ -0,0 +1,95 @@
'''
Author: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
Date: 2023-03-29 20:22:34
LastEditors: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
LastEditTime: 2023-04-05 15:46:07
FilePath: /luyuetong-data/yolox-pytorch/flask_on.py
Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
'''
from flask import Flask, make_response, request, jsonify,send_file
from flask_cors import CORS
from werkzeug.utils import secure_filename
from PIL import Image
from yolo import YOLO
from predict import PREDICT
import os
import quanju
import requests
import urllib.request
from io import BytesIO
import base64
import io
app = Flask(__name__)
CORS(app)
@app.route('/uploads', methods=['GET','POST'])
def uploads():
    """Receive an uploaded image and hand it to the detector.

    On POST the request must carry a multipart field named ``file``. The
    upload is read once, opened as a PIL image and passed to ``PREDICT``.

    Returns:
        ``"ok"`` on success (and for GET requests), or a JSON error body with
        HTTP status 400 when no file was sent, no file was selected, or the
        payload cannot be opened as an image.

    Side effects:
        Stores the raw uploaded bytes in the module-level ``filename``
        variable, as the previous implementation did.
    """
    global filename

    # Only POST requests carry an upload.
    if request.method == 'POST':
        # A missing form field means nothing was uploaded at all.
        if 'file' not in request.files:
            return jsonify({'error': 'No file uploaded.'}), 400

        file = request.files['file']

        # Browsers submit an empty filename when no file was chosen; reject
        # it before doing any work on the payload.
        if file.filename == '':
            return jsonify({'error': 'No file selected.'}), 400

        print(f"file:{file}")

        # Read the stream exactly once: a second read() returns b'' because
        # the stream position is already at the end.
        filename = file.read()

        # The payload is binary image data, so it is opened directly from
        # memory (no UTF-8 decoding and no base64 round-trip).
        try:
            img = Image.open(io.BytesIO(filename))
        except OSError:
            # PIL raises UnidentifiedImageError (an OSError) for non-images.
            return jsonify({'error': 'Invalid file type.'}), 400

        # NOTE(review): PREDICT is defined in predict.py, which is not visible
        # here; it is assumed to accept a PIL image - confirm.
        PREDICT(img)

    return "ok"
def allowed_file(filename, allowed_extensions=('jpeg', 'jpg', 'png')):
    """Return True when *filename* has one of the permitted extensions.

    Args:
        filename: File name to check. Only the text after the last dot is
            inspected, compared case-insensitively.
        allowed_extensions: Lower-case extensions (without the dot) that are
            accepted. Defaults to the jpeg/jpg/png image types.

    Returns:
        bool: False when the name contains no dot or the extension is not in
        ``allowed_extensions``; True otherwise.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in allowed_extensions
if __name__ == '__main__':
    # Folder for saved uploads ('./uploads' is the alternative when the
    # server is started from inside the project directory).
    app.config.update(UPLOAD_FOLDER='yolox-pytorch/uploads')
    # NOTE(review): the bind address is a fixed IP; the server only starts on
    # a machine that owns this address.
    app.run(port=8000, host='10.21.19.104')
+42
View File
@@ -0,0 +1,42 @@
'''
Author: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
Date: 2023-03-29 20:22:34
LastEditors: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
LastEditTime: 2023-04-05 15:27:13
FilePath: /luyuetong-data/yolox-pytorch/flask_on.py
Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
'''
import os
from flask import Flask, request, jsonify
from gevent import pywsgi
from predict import PREDICT
app = Flask(__name__)
# 设置未处理文件路径以及处理保存路径
# dataPath = '/home/DiatomRecognition/upload'
# resultPath = '/home/DiatomRecognition/result'
# localhost 测试
dataPath = 'E:/Document/project/temp/data'
resultPath = r'E:\Document\project\temp/save'
@app.route('/predict', methods=['POST'])
def predict_one():
    """Run the detector on one file named by the ``fileName`` form field.

    The result directory ``resultPath`` is created on demand. The file name
    comes from the client, so it is validated as a bare name before being
    passed on together with ``dataPath`` and ``resultPath``.

    Returns:
        Whatever ``PREDICT`` returns, or a JSON error body with HTTP status
        400 when ``fileName`` is missing, empty, or contains path components.
    """
    # exist_ok avoids the race between an existence check and the creation,
    # and also creates any missing parent folders.
    os.makedirs(resultPath, exist_ok=True)

    # Extract the file name from the request parameters.
    fileName = request.form.get('fileName', '')

    # Reject anything that is not a plain file name so a request cannot point
    # outside the data folder (e.g. "../../etc/passwd"). Both separator
    # styles are checked because the configured paths are Windows paths.
    has_separator = any(sep in fileName for sep in ('/', '\\'))
    if not fileName or fileName in ('.', '..') or has_separator:
        return jsonify({'error': 'Invalid fileName.'}), 400

    # Hand the locations to the model for prediction.
    # NOTE(review): PREDICT comes from predict.py (not visible here); its
    # return value is sent back unchanged.
    result = PREDICT(dataPath, resultPath, fileName)
    return result
if __name__ == '__main__':
    # Emit non-ASCII characters in JSON responses as-is instead of escaping.
    app.config.update(JSON_AS_ASCII=False)
    # Serve the Flask app through gevent's WSGI server on all interfaces.
    listen_address = ('0.0.0.0', 2000)
    server = pywsgi.WSGIServer(listen_address, app)
    server.serve_forever()
+138
View File
@@ -0,0 +1,138 @@
import os
import xml.etree.ElementTree as ET
from PIL import Image
from tqdm import tqdm
from utils.utils import get_classes
from utils.utils_map import get_coco_map, get_map
from yolo import YOLO
if __name__ == "__main__":
    # Recall and Precision, unlike AP, are not area-based measures, so they
    # take different values at different confidence thresholds.
    # By default the Recall and Precision reported here correspond to a
    # threshold of 0.5.
    # Computing mAP requires nearly all prediction boxes so that Recall and
    # Precision can be evaluated at every threshold. The txt files written to
    # map_out/detection-results/ therefore usually contain more boxes than a
    # plain predict run: the goal is to list every possible prediction box.
    #------------------------------------------------------------------------------------------------------------------#
    #   map_mode selects what this script computes when it runs.
    #   map_mode 0: the whole mAP pipeline (predictions, ground truth, VOC mAP).
    #   map_mode 1: only produce the prediction results.
    #   map_mode 2: only produce the ground-truth boxes.
    #   map_mode 3: only compute the VOC mAP.
    #   map_mode 4: compute the 0.50:0.95 mAP of the current dataset with the COCO toolbox.
    #               Requires predictions and ground truth to exist and pycocotools to be installed.
    #-------------------------------------------------------------------------------------------------------------------#
    map_mode        = 0
    #--------------------------------------------------------------------------------------#
    #   classes_path selects the classes for which the VOC mAP is measured.
    #   Normally it is the same classes_path used for training and prediction.
    #--------------------------------------------------------------------------------------#
    classes_path    = 'model_data/voc_classes.txt' #..
    #--------------------------------------------------------------------------------------#
    #   MINOVERLAP selects which mAP0.x is computed,
    #   e.g. set MINOVERLAP = 0.75 to compute mAP0.75.
    #
    #   A prediction box whose overlap with a ground-truth box exceeds MINOVERLAP counts as
    #   a positive sample, otherwise as a negative one. The larger MINOVERLAP is, the more
    #   accurate a box must be to count as positive, and the lower the resulting mAP.
    #--------------------------------------------------------------------------------------#
    MINOVERLAP      = 0.5
    #--------------------------------------------------------------------------------------#
    #   Computing mAP requires nearly all prediction boxes, so confidence should be as
    #   small as possible in order to keep every candidate box.
    #
    #   This value is normally left alone. To obtain Recall and Precision at a different
    #   threshold, change score_threhold below instead.
    #--------------------------------------------------------------------------------------#
    confidence      = 0.001
    #--------------------------------------------------------------------------------------#
    #   IoU used by non-maximum suppression during prediction; larger means less strict.
    #
    #   This value is normally left alone.
    #--------------------------------------------------------------------------------------#
    nms_iou         = 0.5
    #---------------------------------------------------------------------------------------------------------------#
    #   Recall and Precision, unlike AP, are not area-based, so they differ per threshold.
    #
    #   By default the reported Recall and Precision correspond to a threshold of 0.5
    #   (defined here as score_threhold). Because mAP needs nearly all prediction boxes,
    #   the confidence above must not be changed casually; score_threhold exists so the
    #   Recall and Precision at a chosen threshold can be looked up during the mAP run.
    #---------------------------------------------------------------------------------------------------------------#
    score_threhold  = 0.5
    #-------------------------------------------------------#
    #   map_vis toggles visualisation of the VOC mAP run.
    #-------------------------------------------------------#
    map_vis         = False
    #-------------------------------------------------------#
    #   Folder that holds the VOC dataset.
    #   Defaults to the VOC dataset in the project root.
    #-------------------------------------------------------#
    VOCdevkit_path  = 'VOCdevkit'
    #-------------------------------------------------------#
    #   Output folder for the results, map_out by default.
    #-------------------------------------------------------#
    map_out_path    = 'map_out'

    # Read the test split; the context manager closes the handle again.
    with open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")) as split_file:
        image_ids = split_file.read().strip().split()

    # Create the output tree. exist_ok makes each call a no-op when the
    # folder already exists, and makedirs also creates map_out_path itself.
    os.makedirs(map_out_path, exist_ok=True)
    for sub_dir in ('ground-truth', 'detection-results', 'images-optional'):
        os.makedirs(os.path.join(map_out_path, sub_dir), exist_ok=True)

    class_names, _ = get_classes(classes_path)

    if map_mode == 0 or map_mode == 1:
        print("Load model.")
        yolo = YOLO(confidence = confidence, nms_iou = nms_iou)
        print("Load model done.")

        print("Get predict result.")
        for image_id in tqdm(image_ids):
            image_path  = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg")
            image       = Image.open(image_path)
            if map_vis:
                # Keep a copy of the input image for the visualisation step.
                image.save(os.path.join(map_out_path, "images-optional/" + image_id + ".jpg"))
            yolo.get_map_txt(image_id, image, class_names, map_out_path)
        print("Get predict result done.")

    if map_mode == 0 or map_mode == 2:
        print("Get ground truth result.")
        for image_id in tqdm(image_ids):
            with open(os.path.join(map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
                root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot()
                for obj in root.findall('object'):
                    # An object is "difficult" only when the tag exists and is 1.
                    difficult_flag = False
                    difficult_node = obj.find('difficult')
                    if difficult_node is not None:
                        if int(difficult_node.text) == 1:
                            difficult_flag = True

                    # Skip objects whose class is not being evaluated.
                    obj_name = obj.find('name').text
                    if obj_name not in class_names:
                        continue

                    bndbox  = obj.find('bndbox')
                    left    = bndbox.find('xmin').text
                    top     = bndbox.find('ymin').text
                    right   = bndbox.find('xmax').text
                    bottom  = bndbox.find('ymax').text

                    if difficult_flag:
                        new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
                    else:
                        new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
        print("Get ground truth result done.")

    if map_mode == 0 or map_mode == 3:
        print("Get map.")
        get_map(MINOVERLAP, True, score_threhold = score_threhold, path = map_out_path)
        print("Get map done.")

    if map_mode == 4:
        print("Get map.")
        get_coco_map(class_names = class_names, path = map_out_path)
        print("Get map done.")
Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 67 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 437 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 253 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 222 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 255 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 230 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 253 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 305 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 237 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 230 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 MiB

+1
View File
@@ -0,0 +1 @@
#
+231
View File
@@ -0,0 +1,231 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import torch
from torch import nn
class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: ``x * sigmoid(x)``."""

    @staticmethod
    def forward(x):
        """Return ``x`` multiplied element-wise by its own sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate
def get_activation(name="silu", inplace=True):
    """Build the activation module selected by ``name``.

    Args:
        name: one of ``"silu"``, ``"relu"`` or ``"lrelu"``.
        inplace: forwarded to ``nn.ReLU`` / ``nn.LeakyReLU``; the SiLU module
            takes no such option, so it is ignored for ``"silu"``.

    Returns:
        A freshly constructed activation module.

    Raises:
        AttributeError: if ``name`` is not one of the supported names.
    """
    if name == "silu":
        return SiLU()
    if name == "relu":
        return nn.ReLU(inplace=inplace)
    if name == "lrelu":
        # Leaky ReLU with a fixed negative slope of 0.1.
        return nn.LeakyReLU(0.1, inplace=inplace)
    raise AttributeError("Unsupported act type: {}".format(name))
class Focus(nn.Module):
    """Space-to-depth stem.

    Samples the last two dimensions of the input with a step of 2 at the four
    possible (row, column) offsets, concatenates the four sub-grids along
    dim 1 and passes the result through a ``BaseConv``.
    """

    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
        super().__init__()
        # Four sub-grids are stacked on dim 1, hence 4x the input channels.
        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)

    def forward(self, x):
        """Rearrange ``x`` into four half-resolution patches and convolve them."""
        # Concatenation order: top-left, bottom-left, top-right, bottom-right.
        patches = [x[..., row::2, col::2] for col in (0, 1) for row in (0, 1)]
        return self.conv(torch.cat(patches, dim=1))
class BaseConv(nn.Module):
    """Conv2d -> BatchNorm2d -> activation block.

    The convolution is padded with ``(ksize - 1) // 2`` on each side.
    ``fuseforward`` skips the batch-norm layer.
    """

    def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=ksize,
            stride=stride,
            padding=(ksize - 1) // 2,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03)
        self.act = get_activation(act, inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalisation and activation in turn."""
        x = self.conv(x)
        x = self.bn(x)
        return self.act(x)

    def fuseforward(self, x):
        """Apply convolution and activation only (no batch normalisation)."""
        x = self.conv(x)
        return self.act(x)
class DWConv(nn.Module):
    """Depthwise-separable convolution built from two ``BaseConv`` blocks.

    ``dconv`` uses ``groups=in_channels`` (one filter group per input channel);
    ``pconv`` is a 1x1 convolution that maps to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
        super().__init__()
        self.dconv = BaseConv(
            in_channels,
            in_channels,
            ksize=ksize,
            stride=stride,
            groups=in_channels,
            act=act,
        )
        self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act)

    def forward(self, x):
        """Run the depthwise block, then the pointwise block."""
        return self.pconv(self.dconv(x))
class SPPBottleneck(nn.Module):
    """Spatial-pyramid-pooling block.

    A 1x1 ``BaseConv`` halves the channels, the result is max-pooled with each
    kernel size in ``kernel_sizes`` (stride 1, padding ``ks // 2``), and the
    un-pooled and pooled tensors are concatenated on dim 1 before a second
    1x1 ``BaseConv`` maps them to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"):
        super().__init__()
        hidden_channels = in_channels // 2
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
        pools = [nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes]
        self.m = nn.ModuleList(pools)
        # One un-pooled branch plus one branch per pooling kernel.
        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
        self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)

    def forward(self, x):
        """Return the fused multi-scale pooled features."""
        reduced = self.conv1(x)
        branches = [reduced]
        branches.extend(pool(reduced) for pool in self.m)
        return self.conv2(torch.cat(branches, dim=1))
#--------------------------------------------------#
#   Residual unit: the small residual structure
#--------------------------------------------------#
class Bottleneck(nn.Module):
    """Standard bottleneck: 1x1 conv, 3x3 conv and an optional skip connection.

    The skip connection is only used when ``shortcut`` is true and the input
    and output channel counts are equal.
    """

    def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        # 1x1 convolution that reduces the channel count
        # (to 50% of out_channels with the default expansion).
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        # 3x3 convolution that expands the channels again and extracts features.
        conv_cls = DWConv if depthwise else BaseConv
        self.conv2 = conv_cls(hidden_channels, out_channels, 3, stride=1, act=act)
        self.use_add = shortcut and in_channels == out_channels

    def forward(self, x):
        """Apply both convolutions and add the input back when enabled."""
        out = self.conv2(self.conv1(x))
        return out + x if self.use_add else out
class CSPLayer(nn.Module):
    """Cross-stage-partial layer.

    The input goes through two parallel 1x1 convolutions. The first branch is
    refined by ``n`` stacked ``Bottleneck`` units, the second is left as a
    large residual edge; both are concatenated on dim 1 and fused by a third
    1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, n=1, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        # First convolution of the main branch.
        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        # First convolution of the large residual edge.
        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
        # Convolution applied to the concatenated branches.
        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
        # Stack n Bottleneck units (expansion 1.0, so channels stay constant).
        bottlenecks = [
            Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act)
            for _ in range(n)
        ]
        self.m = nn.Sequential(*bottlenecks)

    def forward(self, x):
        """Return the fused output of the main branch and the residual edge."""
        # Main branch: 1x1 conv followed by the stacked residual units.
        main = self.m(self.conv1(x))
        # Large residual edge: 1x1 conv only.
        edge = self.conv2(x)
        # Concatenate both branches and fuse them.
        return self.conv3(torch.cat((main, edge), dim=1))
class CSPDarknet(nn.Module):
    """CSPDarknet backbone: a ``Focus`` stem followed by four strided stages.

    ``forward`` returns a dict with the outputs of the stages whose names are
    listed in ``out_features``. The shape notes in the comments below are the
    original annotations and describe a 640x640x3 input with ``wid_mul = 1``.
    """

    def __init__(self, dep_mul, wid_mul, out_features=("dark3", "dark4", "dark5"), depthwise=False, act="silu",):
        super().__init__()
        assert out_features, "please provide output features of Darknet"
        self.out_features = out_features
        Conv = DWConv if depthwise else BaseConv

        # Base width is 64 channels and base depth is 3 blocks before scaling.
        base_channels = int(wid_mul * 64)
        base_depth = max(round(dep_mul * 3), 1)
        c2, c4, c8, c16 = (base_channels * mult for mult in (2, 4, 8, 16))

        # Focus stem: 640, 640, 3 -> 320, 320, 12 -> 320, 320, 64
        self.stem = Focus(3, base_channels, ksize=3, act=act)

        # Strided conv: 320, 320, 64  -> 160, 160, 128
        # CSPLayer:     160, 160, 128 -> 160, 160, 128
        self.dark2 = nn.Sequential(
            Conv(base_channels, c2, 3, 2, act=act),
            CSPLayer(c2, c2, n=base_depth, depthwise=depthwise, act=act),
        )

        # Strided conv: 160, 160, 128 -> 80, 80, 256
        # CSPLayer:     80, 80, 256   -> 80, 80, 256
        self.dark3 = nn.Sequential(
            Conv(c2, c4, 3, 2, act=act),
            CSPLayer(c4, c4, n=base_depth * 3, depthwise=depthwise, act=act),
        )

        # Strided conv: 80, 80, 256 -> 40, 40, 512
        # CSPLayer:     40, 40, 512 -> 40, 40, 512
        self.dark4 = nn.Sequential(
            Conv(c4, c8, 3, 2, act=act),
            CSPLayer(c8, c8, n=base_depth * 3, depthwise=depthwise, act=act),
        )

        # Strided conv: 40, 40, 512  -> 20, 20, 1024
        # SPP:          20, 20, 1024 -> 20, 20, 1024
        # CSPLayer:     20, 20, 1024 -> 20, 20, 1024 (built with shortcut=False)
        self.dark5 = nn.Sequential(
            Conv(c8, c16, 3, 2, act=act),
            SPPBottleneck(c16, c16, activation=act),
            CSPLayer(c16, c16, n=base_depth, shortcut=False, depthwise=depthwise, act=act),
        )

    def forward(self, x):
        """Run every stage in order and return the requested feature maps.

        Per the original annotations, dark3 / dark4 / dark5 produce the
        80x80x256, 40x40x512 and 20x20x1024 feature layers respectively.
        """
        features = {}
        for stage_name in ("stem", "dark2", "dark3", "dark4", "dark5"):
            x = getattr(self, stage_name)(x)
            if stage_name in self.out_features:
                features[stage_name] = x
        return features
if __name__ == '__main__':
    # Smoke check: build the backbone with unit depth/width multipliers and
    # print its module tree.
    backbone = CSPDarknet(1, 1)
    print(backbone)
+247
View File
@@ -0,0 +1,247 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import torch
import torch.nn as nn
from .darknet import BaseConv, CSPDarknet, CSPLayer, DWConv
class YOLOXHead(nn.Module):
    """Decoupled detection head with one set of branches per input level.

    For every level a 1x1 stem maps the input to ``int(256 * width)`` channels,
    then separate two-conv classification and regression towers feed three 1x1
    prediction convs (class scores, 4 box values, 1 objectness value).
    """

    def __init__(self, num_classes, width=1.0, in_channels=[256, 512, 1024], act="silu", depthwise=False,):
        super().__init__()
        Conv = DWConv if depthwise else BaseConv
        hidden = int(256 * width)

        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.obj_preds = nn.ModuleList()
        self.stems = nn.ModuleList()

        # NOTE: layers are created in the same order as before so that random
        # weight initialisation under a fixed seed is unchanged.
        for level_channels in in_channels:
            self.stems.append(
                BaseConv(in_channels=int(level_channels * width), out_channels=hidden, ksize=1, stride=1, act=act)
            )
            self.cls_convs.append(nn.Sequential(
                Conv(in_channels=hidden, out_channels=hidden, ksize=3, stride=1, act=act),
                Conv(in_channels=hidden, out_channels=hidden, ksize=3, stride=1, act=act),
            ))
            self.cls_preds.append(
                nn.Conv2d(in_channels=hidden, out_channels=num_classes, kernel_size=1, stride=1, padding=0)
            )
            self.reg_convs.append(nn.Sequential(
                Conv(in_channels=hidden, out_channels=hidden, ksize=3, stride=1, act=act),
                Conv(in_channels=hidden, out_channels=hidden, ksize=3, stride=1, act=act),
            ))
            self.reg_preds.append(
                nn.Conv2d(in_channels=hidden, out_channels=4, kernel_size=1, stride=1, padding=0)
            )
            self.obj_preds.append(
                nn.Conv2d(in_channels=hidden, out_channels=1, kernel_size=1, stride=1, padding=0)
            )

    def forward(self, inputs):
        """Return one prediction map per input level.

        Per the original annotations the inputs are P3_out (80, 80, 256),
        P4_out (40, 40, 512) and P5_out (20, 20, 1024). Each output stacks, on
        dim 1, the 4 regression values, 1 objectness value and ``num_classes``
        class scores, in that order.
        """
        outputs = []
        for level, feature in enumerate(inputs):
            # 1x1 convolution to unify the channel count.
            stem_out = self.stems[level](feature)

            # Classification tower (two conv-norm-activation blocks), then the
            # per-location class scores: H, W, num_classes.
            cls_feat = self.cls_convs[level](stem_out)
            cls_output = self.cls_preds[level](cls_feat)

            # Regression tower (two conv-norm-activation blocks), then the
            # per-location box regression values: H, W, 4.
            reg_feat = self.reg_convs[level](stem_out)
            reg_output = self.reg_preds[level](reg_feat)

            # Objectness is predicted from the regression features: H, W, 1.
            obj_output = self.obj_preds[level](reg_feat)

            outputs.append(torch.cat([reg_output, obj_output, cls_output], 1))
        return outputs
class YOLOPAFPN(nn.Module):
    """CSPDarknet backbone plus a top-down / bottom-up feature-fusion neck.

    ``forward`` returns three fused feature maps ``(P3_out, P4_out, P5_out)``.
    Shape notes in the comments are the original annotations and describe a
    640x640 input with ``width = 1``.
    """

    def __init__(self, depth=1.0, width=1.0, in_features=("dark3", "dark4", "dark5"), in_channels=[256, 512, 1024], depthwise=False, act="silu"):
        super().__init__()
        Conv = DWConv if depthwise else BaseConv
        n_blocks = round(3 * depth)

        self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
        self.in_features = in_features
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

        # 20, 20, 1024 -> 20, 20, 512
        self.lateral_conv0 = BaseConv(int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act)

        # 40, 40, 1024 -> 40, 40, 512
        self.C3_p4 = CSPLayer(
            int(2 * in_channels[1] * width),
            int(in_channels[1] * width),
            n_blocks,
            False,
            depthwise=depthwise,
            act=act,
        )

        # 40, 40, 512 -> 40, 40, 256
        self.reduce_conv1 = BaseConv(int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act)

        # 80, 80, 512 -> 80, 80, 256
        self.C3_p3 = CSPLayer(
            int(2 * in_channels[0] * width),
            int(in_channels[0] * width),
            n_blocks,
            False,
            depthwise=depthwise,
            act=act,
        )

        # 80, 80, 256 -> 40, 40, 256
        self.bu_conv2 = Conv(int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act)

        # 40, 40, 512 -> 40, 40, 512
        self.C3_n3 = CSPLayer(
            int(2 * in_channels[0] * width),
            int(in_channels[1] * width),
            n_blocks,
            False,
            depthwise=depthwise,
            act=act,
        )

        # 40, 40, 512 -> 20, 20, 512
        self.bu_conv1 = Conv(int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act)

        # 20, 20, 1024 -> 20, 20, 1024
        self.C3_n4 = CSPLayer(
            int(2 * in_channels[1] * width),
            int(in_channels[2] * width),
            n_blocks,
            False,
            depthwise=depthwise,
            act=act,
        )

    def forward(self, input):
        """Fuse the three backbone feature maps and return ``(P3, P4, P5)``."""
        out_features = self.backbone.forward(input)
        feat1, feat2, feat3 = [out_features[f] for f in self.in_features]

        # ---- top-down path ----
        # 20, 20, 1024 -> 20, 20, 512
        P5 = self.lateral_conv0(feat3)
        # Upsample to 40, 40, 512 and concatenate with feat2 -> 40, 40, 1024
        td_p4 = torch.cat([self.upsample(P5), feat2], 1)
        # 40, 40, 1024 -> 40, 40, 512
        td_p4 = self.C3_p4(td_p4)
        # 40, 40, 512 -> 40, 40, 256
        P4 = self.reduce_conv1(td_p4)
        # Upsample to 80, 80, 256 and concatenate with feat1 -> 80, 80, 512
        td_p3 = torch.cat([self.upsample(P4), feat1], 1)
        # 80, 80, 512 -> 80, 80, 256
        P3_out = self.C3_p3(td_p3)

        # ---- bottom-up path ----
        # Downsample 80, 80, 256 -> 40, 40, 256, concatenate with P4 -> 40, 40, 512
        bu_p4 = torch.cat([self.bu_conv2(P3_out), P4], 1)
        # 40, 40, 512 -> 40, 40, 512
        P4_out = self.C3_n3(bu_p4)
        # Downsample 40, 40, 512 -> 20, 20, 512, concatenate with P5 -> 20, 20, 1024
        bu_p5 = torch.cat([self.bu_conv1(P4_out), P5], 1)
        # 20, 20, 1024 -> 20, 20, 1024
        P5_out = self.C3_n4(bu_p5)

        return (P3_out, P4_out, P5_out)
class YoloBody(nn.Module):
    """Complete detector: ``YOLOPAFPN`` backbone/neck followed by ``YOLOXHead``.

    ``phi`` selects the model scale; an unknown value raises ``KeyError``.
    """

    def __init__(self, num_classes, phi):
        super().__init__()
        # phi -> (depth multiplier, width multiplier)
        scale_table = {
            'nano': (0.33, 0.25),
            'tiny': (0.33, 0.375),
            's': (0.33, 0.50),
            'm': (0.67, 0.75),
            'l': (1.00, 1.00),
            'x': (1.33, 1.25),
        }
        depth, width = scale_table[phi]
        # Only the 'nano' scale uses depthwise convolutions.
        depthwise = phi == 'nano'
        self.backbone = YOLOPAFPN(depth, width, depthwise=depthwise)
        self.head = YOLOXHead(num_classes, width, depthwise=depthwise)

    def forward(self, x):
        """Run the backbone/neck and feed its outputs to the head."""
        fpn_outs = self.backbone.forward(x)
        return self.head.forward(fpn_outs)
+488
View File
@@ -0,0 +1,488 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import math
from copy import deepcopy
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
class IOUloss(nn.Module):
    """IoU-based box regression loss.

    Boxes are read as ``(center_x, center_y, width, height)``: the first two
    values are treated as the centre and the last two as the size.

    Args:
        reduction: ``"mean"`` or ``"sum"`` reduce the per-box loss to a
            scalar; any other value returns the per-box loss unchanged.
        loss_type: ``"iou"`` (``1 - iou ** 2``) or ``"giou"``
            (``1 - clamp(giou, -1, 1)``).
    """

    _SUPPORTED_LOSS_TYPES = ("iou", "giou")

    def __init__(self, reduction="none", loss_type="iou"):
        super(IOUloss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        """Return the loss between ``pred`` and ``target`` boxes.

        Both tensors are reshaped to ``[-1, 4]`` and must agree in their
        first dimension.

        Raises:
            ValueError: if ``self.loss_type`` is not ``"iou"`` or ``"giou"``.
                Previously this case fell through both branches and failed
                with an ``UnboundLocalError`` on ``loss``.
        """
        if self.loss_type not in self._SUPPORTED_LOSS_TYPES:
            raise ValueError(
                "Unsupported loss_type: {!r} (expected one of {})".format(
                    self.loss_type, self._SUPPORTED_LOSS_TYPES
                )
            )
        assert pred.shape[0] == target.shape[0]

        pred = pred.view(-1, 4)
        target = target.view(-1, 4)

        # Corner coordinates, computed once and shared by the IoU and GIoU terms.
        pred_tl = pred[:, :2] - pred[:, 2:] / 2
        pred_br = pred[:, :2] + pred[:, 2:] / 2
        target_tl = target[:, :2] - target[:, 2:] / 2
        target_br = target[:, :2] + target[:, 2:] / 2

        # Intersection rectangle.
        tl = torch.max(pred_tl, target_tl)
        br = torch.min(pred_br, target_br)

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)

        # `en` is 1 only where the boxes overlap on both axes; it zeroes the
        # otherwise meaningless "intersection" of disjoint boxes.
        en = (tl < br).type(tl.type()).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        area_u = area_p + area_g - area_i
        iou = area_i / (area_u + 1e-16)

        if self.loss_type == "iou":
            loss = 1 - iou ** 2
        else:  # "giou"
            # Smallest box enclosing both pred and target.
            c_tl = torch.min(pred_tl, target_tl)
            c_br = torch.max(pred_br, target_br)
            area_c = torch.prod(c_br - c_tl, 1)
            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss
class YOLOLoss(nn.Module):
    """Training loss for the YOLOX-style head with dynamic-k label assignment.

    ``forward`` decodes the raw head outputs of every stride level, assigns
    ground-truth boxes to anchor points (``get_assignments``) and returns
    ``(5 * iou_loss + objectness_loss + class_loss) / num_foreground``.

    Args:
        num_classes: number of object classes.
        fp16: when true, the class-cost part of the assignment is computed
            with CUDA autocast disabled.
        strides: one stride per head output level; never mutated here.
    """

    def __init__(self, num_classes, fp16, strides=[8, 16, 32]):
        super().__init__()
        self.num_classes = num_classes
        self.strides = strides

        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
        self.iou_loss = IOUloss(reduction="none")
        # One cached coordinate grid per level; the placeholders are replaced
        # lazily by get_output_and_grid.
        self.grids = [torch.zeros(1)] * len(strides)
        self.fp16 = fp16

    def forward(self, inputs, labels=None):
        """Compute the loss for one batch.

        Args:
            inputs: one ``[batch, num_classes + 5, H, W]`` map per stride.
            labels: per-image sequence of ground-truth rows; the first four
                columns are read as ``(cx, cy, w, h)`` and the fifth as the
                class index.
        """
        outputs = []
        x_shifts = []
        y_shifts = []
        expanded_strides = []

        # Per level:
        #   output            [batch, H*W, num_classes + 5]
        #   x/y shift, stride [1, H*W]
        for k, (stride, output) in enumerate(zip(self.strides, inputs)):
            output, grid = self.get_output_and_grid(output, k, stride)
            x_shifts.append(grid[:, :, 0])
            y_shifts.append(grid[:, :, 1])
            expanded_strides.append(torch.ones_like(grid[:, :, 0]) * stride)
            outputs.append(output)

        return self.get_losses(x_shifts, y_shifts, expanded_strides, labels, torch.cat(outputs, 1))

    def get_output_and_grid(self, output, k, stride):
        """Decode one level and return ``(decoded_output, grid)``.

        The grid of integer ``(x, y)`` cell coordinates is cached in
        ``self.grids[k]`` and only rebuilt when the feature-map size or the
        tensor type changes. Channels 0-1 become ``(raw + cell) * stride`` and
        channels 2-3 become ``exp(raw) * stride``; the remaining channels are
        left untouched.
        """
        grid = self.grids[k]
        hsize, wsize = output.shape[-2:]
        # The cached grid has shape (1, H, W, 2), so its spatial size lives in
        # dims 1:3. (The previous check compared dims 2:4, i.e. (W, 2), with
        # the output's (H, W); it never matched, so the grid was rebuilt on
        # every call.)
        if grid.shape[1:3] != output.shape[2:4] or grid.type() != output.type():
            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, hsize, wsize, 2).type(output.type())
            self.grids[k] = grid
        grid = grid.view(1, -1, 2)

        # [batch, C, H, W] -> [batch, H*W, C]
        output = output.flatten(start_dim=2).permute(0, 2, 1)
        output[..., :2] = (output[..., :2] + grid.type_as(output)) * stride
        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
        return output, grid

    def get_losses(self, x_shifts, y_shifts, expanded_strides, labels, outputs):
        """Assign targets image by image and return the normalised total loss."""
        # [batch, n_anchors_all, 4]
        bbox_preds = outputs[:, :, :4]
        # [batch, n_anchors_all, 1]
        obj_preds = outputs[:, :, 4:5]
        # [batch, n_anchors_all, n_cls]
        cls_preds = outputs[:, :, 5:]

        total_num_anchors = outputs.shape[1]
        # x_shifts, y_shifts, expanded_strides: [1, n_anchors_all]
        x_shifts = torch.cat(x_shifts, 1).type_as(outputs)
        y_shifts = torch.cat(y_shifts, 1).type_as(outputs)
        expanded_strides = torch.cat(expanded_strides, 1).type_as(outputs)

        cls_targets = []
        reg_targets = []
        obj_targets = []
        fg_masks = []

        num_fg = 0.0
        for batch_idx in range(outputs.shape[0]):
            num_gt = len(labels[batch_idx])
            if num_gt == 0:
                # No ground truth: every anchor is background.
                cls_target = outputs.new_zeros((0, self.num_classes))
                reg_target = outputs.new_zeros((0, 4))
                obj_target = outputs.new_zeros((total_num_anchors, 1))
                fg_mask = outputs.new_zeros(total_num_anchors).bool()
            else:
                # gt_bboxes_per_image     [num_gt, 4]
                # gt_classes              [num_gt]
                # bboxes_preds_per_image  [n_anchors_all, 4]
                # cls_preds_per_image     [n_anchors_all, num_classes]
                # obj_preds_per_image     [n_anchors_all, 1]
                gt_bboxes_per_image = labels[batch_idx][..., :4].type_as(outputs)
                gt_classes = labels[batch_idx][..., 4].type_as(outputs)
                bboxes_preds_per_image = bbox_preds[batch_idx]
                cls_preds_per_image = cls_preds[batch_idx]
                obj_preds_per_image = obj_preds[batch_idx]

                gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments(
                    num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image,
                    expanded_strides, x_shifts, y_shifts,
                )
                torch.cuda.empty_cache()
                num_fg += num_fg_img
                # Class target is the one-hot class scaled by the matched IoU.
                cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1)
                obj_target = fg_mask.unsqueeze(-1)
                reg_target = gt_bboxes_per_image[matched_gt_inds]
            cls_targets.append(cls_target)
            reg_targets.append(reg_target)
            obj_targets.append(obj_target.type(cls_target.type()))
            fg_masks.append(fg_mask)

        cls_targets = torch.cat(cls_targets, 0)
        reg_targets = torch.cat(reg_targets, 0)
        obj_targets = torch.cat(obj_targets, 0)
        fg_masks = torch.cat(fg_masks, 0)

        # Avoid dividing by zero when no anchor was assigned.
        num_fg = max(num_fg, 1)
        loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum()
        loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum()
        loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()
        reg_weight = 5.0
        loss = reg_weight * loss_iou + loss_obj + loss_cls

        return loss / num_fg

    def _pair_wise_cls_loss(self, cls_preds_, obj_preds_, gt_classes, num_gt, num_in_boxes_anchor):
        """Return the ``[num_gt, n_candidates]`` classification cost matrix.

        The score is ``sqrt(sigmoid(cls) * sigmoid(obj))``, compared with the
        one-hot ground-truth class via binary cross-entropy summed over classes.
        """
        # [num_gt, n_candidates, num_classes]
        joint_scores = (
            cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
            * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
        )
        gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
        return F.binary_cross_entropy(joint_scores.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)

    @torch.no_grad()
    def get_assignments(self, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image, expanded_strides, x_shifts, y_shifts):
        """Assign ground-truth boxes to anchor points for one image.

        Returns:
            ``(gt_matched_classes, fg_mask, pred_ious_this_matching,
            matched_gt_inds, num_fg)``; ``fg_mask`` is ``[n_anchors_all]`` and
            the other tensors have one entry per foreground anchor.
        """
        # fg_mask                 [n_anchors_all]   candidate anchors
        # is_in_boxes_and_center  [num_gt, n_candidates]
        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt)

        # Keep only the candidate anchors:
        #   bboxes_preds_per_image [n_candidates, 4]
        #   cls_preds_             [n_candidates, num_classes]
        #   obj_preds_             [n_candidates, 1]
        bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
        cls_preds_ = cls_preds_per_image[fg_mask]
        obj_preds_ = obj_preds_per_image[fg_mask]
        num_in_boxes_anchor = bboxes_preds_per_image.shape[0]

        # pair_wise_ious [num_gt, n_candidates]
        pair_wise_ious = self.bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)
        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)

        if self.fp16:
            # Compute the class cost with CUDA autocast switched off.
            with torch.cuda.amp.autocast(enabled=False):
                pair_wise_cls_loss = self._pair_wise_cls_loss(cls_preds_, obj_preds_, gt_classes, num_gt, num_in_boxes_anchor)
        else:
            pair_wise_cls_loss = self._pair_wise_cls_loss(cls_preds_, obj_preds_, gt_classes, num_gt, num_in_boxes_anchor)
        del cls_preds_

        # Anchors that are not inside both the box and the centre region get a
        # very large penalty.
        cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center).float()

        num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
        return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg

    def bboxes_iou(self, bboxes_a, bboxes_b, xyxy=True):
        """Return the pairwise IoU matrix ``[len(bboxes_a), len(bboxes_b)]``.

        With ``xyxy=True`` boxes are ``(x1, y1, x2, y2)``; otherwise they are
        ``(cx, cy, w, h)``.

        Raises:
            IndexError: if either input does not have 4 columns.
        """
        if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
            raise IndexError

        if xyxy:
            tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
            br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
            area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
            area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
        else:
            tl = torch.max(
                (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
                (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
            )
            br = torch.min(
                (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
                (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
            )
            area_a = torch.prod(bboxes_a[:, 2:], 1)
            area_b = torch.prod(bboxes_b[:, 2:], 1)

        # `en` zeroes the intersection of boxes that do not overlap on both axes.
        en = (tl < br).type(tl.type()).prod(dim=2)
        area_i = torch.prod(br - tl, 2) * en
        return area_i / (area_a[:, None] + area_b - area_i)

    def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius = 2.5):
        """Select candidate anchors for one image.

        An anchor is a candidate when its centre lies inside any ground-truth
        box or inside any square of half-size ``center_radius * stride``
        around a ground-truth centre.

        Returns:
            ``(is_in_boxes_anchor, is_in_boxes_and_center)``: a
            ``[n_anchors_all]`` candidate mask and, for the candidates only, a
            ``[num_gt, n_candidates]`` mask of anchors that satisfy both tests
            for that ground truth.
        """
        # expanded_strides_per_image  [n_anchors_all]
        # x/y_centers_per_image       [num_gt, n_anchors_all]
        expanded_strides_per_image = expanded_strides[0]
        x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
        y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)

        # Ground-truth box edges, each [num_gt, n_anchors_all].
        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)

        # bbox_deltas [num_gt, n_anchors_all, 4]: distance from the anchor
        # centre to each box edge; all positive means "inside the box".
        b_l = x_centers_per_image - gt_bboxes_per_image_l
        b_r = gt_bboxes_per_image_r - x_centers_per_image
        b_t = y_centers_per_image - gt_bboxes_per_image_t
        b_b = gt_bboxes_per_image_b - y_centers_per_image
        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)

        # is_in_boxes      [num_gt, n_anchors_all]
        # is_in_boxes_all  [n_anchors_all]
        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0

        # Edges of the centre region around every ground-truth centre.
        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)

        # center_deltas [num_gt, n_anchors_all, 4]
        c_l = x_centers_per_image - gt_bboxes_per_image_l
        c_r = gt_bboxes_per_image_r - x_centers_per_image
        c_t = y_centers_per_image - gt_bboxes_per_image_t
        c_b = gt_bboxes_per_image_b - y_centers_per_image
        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)

        # is_in_centers      [num_gt, n_anchors_all]
        # is_in_centers_all  [n_anchors_all]
        is_in_centers = center_deltas.min(dim=-1).values > 0.0
        is_in_centers_all = is_in_centers.sum(dim=0) > 0

        # is_in_boxes_anchor      [n_anchors_all]
        # is_in_boxes_and_center  [num_gt, n_candidates]
        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
        is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
        return is_in_boxes_anchor, is_in_boxes_and_center

    def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
        """Pick, for every ground truth, its ``k`` lowest-cost candidate anchors.

        ``fg_mask`` is updated in place so that only the matched anchors stay
        set.

        Args:
            cost, pair_wise_ious: ``[num_gt, n_candidates]``.
            gt_classes: ``[num_gt]``.
            fg_mask: ``[n_anchors_all]`` candidate mask.

        Returns:
            ``(num_fg, gt_matched_classes, pred_ious_this_matching,
            matched_gt_inds)``.
        """
        # matching_matrix [num_gt, n_candidates]
        matching_matrix = torch.zeros_like(cost)

        # Take the (at most 10) largest IoUs of every ground truth and sum
        # them; the integer part, at least 1, is how many anchors that ground
        # truth is given.
        #   topk_ious   [num_gt, n_candidate_k]
        #   dynamic_ks  [num_gt]
        n_candidate_k = min(10, pair_wise_ious.size(1))
        topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)

        for gt_idx in range(num_gt):
            # Select the dynamic-k anchors with the smallest cost.
            _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
            matching_matrix[gt_idx][pos_idx] = 1.0
        del topk_ious, dynamic_ks, pos_idx

        # anchor_matching_gt [n_candidates]: ground truths claimed per anchor.
        anchor_matching_gt = matching_matrix.sum(0)
        if (anchor_matching_gt > 1).sum() > 0:
            # An anchor matched to several ground truths keeps only the one
            # with the smallest cost.
            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
            matching_matrix[:, anchor_matching_gt > 1] *= 0.0
            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0

        # fg_mask_inboxes [n_candidates]; num_fg is the number of positives.
        fg_mask_inboxes = matching_matrix.sum(0) > 0.0
        num_fg = fg_mask_inboxes.sum().item()

        # Narrow fg_mask from "candidate" to "matched" anchors.
        fg_mask[fg_mask.clone()] = fg_mask_inboxes

        # Ground-truth index and class of every matched anchor.
        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
        gt_matched_classes = gt_classes[matched_gt_inds]

        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
def is_parallel(model):
    """Return True if `model` is a DataParallel or DistributedDataParallel wrapper.

    The comparison is made on the exact type of `model`, so subclasses of the
    two wrapper classes are not treated as parallel models.
    """
    wrapper_types = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    model_type = type(model)
    return any(model_type is wrapper for wrapper in wrapper_types)
def de_parallel(model):
    """Unwrap a DP/DDP model.

    Returns `model.module` when `is_parallel(model)` is true, otherwise the
    model itself.
    """
    if is_parallel(model):
        return model.module
    return model
def copy_attr(a, b, include=(), exclude=()):
    """Copy instance attributes from `b` onto `a`.

    An attribute found in `b.__dict__` is copied only when all of these hold:
      * `include` is empty, or the attribute name is listed in `include`;
      * the name does not start with an underscore;
      * the name is not listed in `exclude`.
    """
    for name, value in b.__dict__.items():
        selected = not len(include) or name in include
        if selected and not name.startswith('_') and name not in exclude:
            setattr(a, name, value)
class ModelEMA:
    """Exponential Moving Average (EMA) of a model's state_dict.

    Adapted from https://github.com/rwightman/pytorch-image-models.
    The average covers every floating-point entry of the state_dict
    (parameters and buffers). For background on EMA see
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        """Build the EMA copy from `model`.

        Args:
            model: model to track; a DP/DDP wrapper is unwrapped first.
            decay: upper bound of the decay factor.
            tau: time constant of the exponential ramp applied to `decay`.
            updates: number of EMA updates already performed.
        """
        # The EMA lives in a deep copy of the unwrapped model, in eval mode (FP32).
        unwrapped = de_parallel(model)
        self.ema = deepcopy(unwrapped).eval()
        # Number of EMA updates performed so far.
        self.updates = updates
        # Decay ramps up exponentially with the update count (helps early epochs).
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        # The EMA copy is never trained directly.
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """Blend the current state of `model` into the EMA weights, in place."""
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            source_state = de_parallel(model).state_dict()
            for key, ema_value in self.ema.state_dict().items():
                # Non-floating entries are left untouched.
                if not ema_value.dtype.is_floating_point:
                    continue
                ema_value *= d
                ema_value += (1 - d) * source_state[key].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy plain attributes from `model` onto the EMA model via `copy_attr`."""
        copy_attr(self.ema, model, include, exclude)
def weights_init(net, init_type='normal', init_gain = 0.02):
    """Re-initialise the weights of `net` in place.

    Modules whose class name contains 'Conv' (and that have a `weight`
    attribute) are initialised according to `init_type`: 'normal', 'xavier',
    'kaiming' or 'orthogonal'. Modules whose class name contains
    'BatchNorm2d' get weights drawn from N(1.0, 0.02) and zero bias.

    Raises:
        NotImplementedError: when a Conv module is visited and `init_type`
            is not one of the supported names.
    """
    def init_conv_weight(weight):
        # Apply the selected scheme to one convolution weight tensor.
        if init_type == 'normal':
            torch.nn.init.normal_(weight, 0.0, init_gain)
        elif init_type == 'xavier':
            torch.nn.init.xavier_normal_(weight, gain=init_gain)
        elif init_type == 'kaiming':
            torch.nn.init.kaiming_normal_(weight, a=0, mode='fan_in')
        elif init_type == 'orthogonal':
            torch.nn.init.orthogonal_(weight, gain=init_gain)

    def init_func(m):
        layer_name = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in layer_name:
            if init_type not in ('normal', 'xavier', 'kaiming', 'orthogonal'):
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            init_conv_weight(m.weight.data)
        elif 'BatchNorm2d' in layer_name:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)

    print('initialize network with %s type' % init_type)
    net.apply(init_func)
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
    """Build a function mapping an iteration index to a learning rate.

    Args:
        lr_decay_type: "cos" selects warm-up + cosine annealing; any other
            value selects stepwise exponential decay.
        lr: maximum learning rate.
        min_lr: minimum learning rate.
        total_iters: total number of iterations the schedule spans.
        warmup_iters_ratio: fraction of `total_iters` used for warm-up
            (the resulting count is clamped to [1, 3]).
        warmup_lr_ratio: warm-up starts at `warmup_lr_ratio * lr`
            (but never below 1e-6).
        no_aug_iter_ratio: fraction of `total_iters` held at `min_lr` at the
            end (the resulting count is clamped to [1, 15]).
        step_num: number of steps of the step schedule.

    Returns:
        A `functools.partial` taking a single argument `iters`.
    """
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        # Quadratic warm-up from warmup_lr_start up to lr.
        if iters <= warmup_total_iters:
            return (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
        # Final phase: hold the minimum learning rate.
        if iters >= total_iters - no_aug_iter:
            return min_lr
        # In between: cosine annealing from lr down to min_lr.
        cosine = math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
        return min_lr + 0.5 * (lr - min_lr) * (
            1.0 + cosine
        )

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must above 1.")
        return lr * decay_rate ** (iters // step_size)

    if lr_decay_type != "cos":
        # Decay rate chosen so that min_lr is reached after (step_num - 1) decays.
        decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
        step_size = total_iters / step_num
        return partial(step_lr, lr, decay_rate, step_size)

    warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
    warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
    no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
    return partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    """Set the learning rate of every param group to `lr_scheduler_func(epoch)`."""
    new_lr = lr_scheduler_func(epoch)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
+180
View File
@@ -0,0 +1,180 @@
# -----------------------------------------------------------------------#
# predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能
# 整合到了一个py文件中,通过指定mode进行模式的修改。
# -----------------------------------------------------------------------#
import os
import time
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from yolo import YOLO
def PREDICT(dataPath, resultPath, fileName):
    """Run YOLO detection according to the hard-coded `mode` below.

    With the default ``mode = "predict"`` the image ``dataPath/fileName`` is
    opened, passed to ``yolo.detect_image``, displayed, saved to
    ``resultPath/fileName``, and the detection result is returned.

    Args:
        dataPath: directory containing the input image.
        resultPath: directory the annotated image is written to.
        fileName: image file name, used for both the input and the output.

    Returns:
        In "predict" mode, the second value returned by ``yolo.detect_image``
        (per the original comment: the predicted class names and scores).
        The other modes return None.

    Raises:
        ValueError: in "video" mode when the camera/video cannot be read.
        AssertionError: when `mode` is not one of the supported values.
    """
    yolo = YOLO()
    # ----------------------------------------------------------------------------------------------------------#
    #   mode selects what is run:
    #   'predict'      single-image prediction
    #   'video'        video detection, from a camera or a video file
    #   'fps'          FPS measurement on a fixed test image
    #   'dir_predict'  run detection on every image of a folder and save the results
    #   'heatmap'      visualise the prediction as a heat map
    #   'export_onnx'  export the model to ONNX (the original note says pytorch 1.7.1+ is required)
    # ----------------------------------------------------------------------------------------------------------#
    mode = "predict"
    # -------------------------------------------------------------------------#
    #   crop   whether detected objects are cropped after a single-image prediction
    #   count  whether detected objects are counted
    #   crop and count are only used when mode == 'predict'
    # -------------------------------------------------------------------------#
    crop = False
    count = False
    # ----------------------------------------------------------------------------------------------------------#
    #   video_path       video source; 0 means the camera, "xxx.mp4" reads that file
    #   video_save_path  where the processed video is saved; "" disables saving
    #   video_fps        fps of the saved video
    #
    #   These three are only used when mode == 'video'.
    #   The saved video is only finalised after ESC is pressed or the last frame is reached.
    # ----------------------------------------------------------------------------------------------------------#
    video_path = 0
    video_save_path = ""
    video_fps = 25.0
    # ----------------------------------------------------------------------------------------------------------#
    #   test_interval   number of detections used to measure fps (larger is more accurate)
    #   fps_image_path  image used for the fps measurement
    #
    #   Both are only used when mode == 'fps'.
    # ----------------------------------------------------------------------------------------------------------#
    test_interval = 100
    fps_image_path = "img/street.jpg"
    # -------------------------------------------------------------------------#
    #   dir_origin_path  folder holding the images to run detection on
    #   dir_save_path    folder the processed images are saved to
    #
    #   Both are only used when mode == 'dir_predict'.
    # -------------------------------------------------------------------------#
    dir_origin_path = "yolox-pytorch/img/"
    dir_save_path = "img_out/"
    # -------------------------------------------------------------------------#
    #   heatmap_save_path  where the heat map is saved (only used when mode == 'heatmap')
    # -------------------------------------------------------------------------#
    heatmap_save_path = "model_data/heatmap_vision.png"
    # -------------------------------------------------------------------------#
    #   simplify        whether the exported ONNX model is simplified
    #   onnx_save_path  where the ONNX model is saved
    # -------------------------------------------------------------------------#
    simplify = True
    onnx_save_path = "model_data/models.onnx"

    if mode == "predict":
        # Notes carried over from the original file:
        # 1. The annotated image can be saved with r_image.save(...), as done below.
        # 2. Box coordinates (top, left, bottom, right) are available in the drawing
        #    part of yolo.detect_image.
        # 3. To crop detected objects, use those four values to slice the original image
        #    inside yolo.detect_image.
        # 4. To draw extra text (e.g. the number of objects of one class), test
        #    predicted_class inside yolo.detect_image and write with draw.text.
        imgPath = os.path.join(dataPath, fileName)
        resultFilePath = os.path.join(resultPath, fileName)

        image = Image.open(imgPath)
        r_image, result = yolo.detect_image(image, crop=crop, count=count)
        r_image.show()
        # Save the annotated image.
        r_image.save(resultFilePath)
        # Per the original comment: the predicted class names and scores.
        return result

    elif mode == "video":
        capture = cv2.VideoCapture(video_path)
        if video_save_path != "":
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

        # Probe read: fail early if the source cannot deliver a frame.
        ref, frame = capture.read()
        if not ref:
            raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")

        fps = 0.0
        while (True):
            t1 = time.time()
            # Read one frame.
            ref, frame = capture.read()
            if not ref:
                break
            # BGR -> RGB, then wrap as a PIL image.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(np.uint8(frame))
            # detect_image returns (image, result), as unpacked in 'predict' mode;
            # only the annotated image is needed here.
            r_frame, _ = yolo.detect_image(frame)
            frame = np.array(r_frame)
            # RGB -> BGR for OpenCV display.
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            # Running average of the instantaneous fps.
            fps = (fps + (1. / (time.time() - t1))) / 2
            print("fps= %.2f" % (fps))
            frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            cv2.imshow("video", frame)
            c = cv2.waitKey(1) & 0xff
            if video_save_path != "":
                out.write(frame)

            # ESC stops the loop; the capture is released once, after the loop.
            if c == 27:
                break

        print("Video Detection Done!")
        capture.release()
        if video_save_path != "":
            print("Save processed video to the path :" + video_save_path)
            out.release()
        cv2.destroyAllWindows()

    elif mode == "fps":
        img = Image.open(fps_image_path)
        tact_time = yolo.get_FPS(img, test_interval)
        print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1')

    elif mode == "dir_predict":
        img_names = os.listdir(dir_origin_path)
        for img_name in tqdm(img_names):
            if img_name.lower().endswith(
                    ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
                image_path = os.path.join(dir_origin_path, img_name)
                image = Image.open(image_path)
                # detect_image returns (image, result); keep the annotated image only.
                r_image, _ = yolo.detect_image(image)
                os.makedirs(dir_save_path, exist_ok=True)
                r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)

    elif mode == "heatmap":
        # Interactive loop: keeps asking for file names until the process is interrupted.
        while True:
            img = input('Input image filename:')
            try:
                image = Image.open(img)
            except Exception:
                # Best effort: report and ask again. KeyboardInterrupt/SystemExit
                # are no longer swallowed.
                print('Open Error! Try again!')
                continue
            else:
                yolo.detect_heatmap(image, heatmap_save_path)

    elif mode == "export_onnx":
        yolo.convert_to_onnx(simplify, onnx_save_path)

    else:
        raise AssertionError(
            "Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.")
+9
View File
@@ -0,0 +1,9 @@
'''
Author: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
Date: 2023-04-01 21:34:17
LastEditors: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
LastEditTime: 2023-04-01 21:41:27
FilePath: /luyuetong-data/yolox-pytorch/quanju.py
Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
'''
# NOTE(review): module-level string, initially empty. Presumably a shared global
# holding a name ("mingzi") for other modules to read/write — TODO confirm against callers.
mingzi = ''
+9
View File
@@ -0,0 +1,9 @@
scipy==1.2.1
numpy==1.17.0
matplotlib==3.1.2
opencv_python==4.1.2.30
torch==1.2.0
torchvision==0.4.0
tqdm==4.60.0
Pillow==8.2.0
h5py==2.10.0
Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 180 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 152 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 152 KiB

+31
View File
@@ -0,0 +1,31 @@
#--------------------------------------------#
# 该部分代码用于看网络结构
#--------------------------------------------#
import torch
from thop import clever_format, profile
from torchsummary import summary
from nets.yolo import YoloBody
if __name__ == "__main__":
    # Network under inspection: input size, class count and YoloX variant.
    input_shape = [640, 640]
    num_classes = 80
    phi = 'l'
    height, width = input_shape

    # The device decides whether the network runs on the GPU or on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    m = YoloBody(num_classes, phi).to(device)

    # Layer-by-layer summary.
    summary(m, (3, height, width))

    # FLOPs / parameter count measured on one random image.
    dummy_input = torch.randn(1, 3, height, width).to(device)
    flops, params = profile(m.to(device), (dummy_input, ), verbose=False)
    #--------------------------------------------------------#
    #   flops is multiplied by 2 because profile does not count
    #   a convolution as two operations.
    #   Some papers count the multiply and the add separately
    #   (then multiply by 2); others only count multiplies
    #   (then do not). This script multiplies by 2, following YOLOX.
    #--------------------------------------------------------#
    flops = flops * 2
    flops, params = clever_format([flops, params], "%.3f")
    print('Total GFLOPS: %s' % (flops))
    print('Total params: %s' % (params))
+15
View File
@@ -0,0 +1,15 @@
from flask import Flask, jsonify, request, make_response, render_template
from PIL import Image
from yolo import YOLO
from predict import PREDICT
import calendar,time,os
import numpy as np
# Load the .npy file with np.load.
# The path is a raw string: the original literal relied on the invalid escape
# sequence "\T", which Python warns about. The runtime value of the path is unchanged.
data = np.load(r'E:\Temp/train_64/train_64.npy')
# Show the shape of the loaded array.
print(data.shape)
+543
View File
@@ -0,0 +1,543 @@
# -------------------------------------#
# 对数据集进行训练
# -------------------------------------#
import datetime
import os
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from nets.yolo import YoloBody
from nets.yolo_training import (ModelEMA, YOLOLoss, get_lr_scheduler,
set_optimizer_lr, weights_init)
from utils.callbacks import LossHistory, EvalCallback
from utils.dataloader import YoloDataset, yolo_dataset_collate
from utils.utils import get_classes, show_config
from utils.utils_fit import fit_one_epoch
'''
训练自己的目标检测模型一定需要注意以下几点:
1、训练前仔细检查自己的格式是否满足要求,该库要求数据集格式为VOC格式,需要准备好的内容有输入图片和标签
输入图片为.jpg图片,无需固定大小,传入训练前会自动进行resize。
灰度图会自动转成RGB图片进行训练,无需自己修改。
输入图片如果后缀非jpg,需要自己批量转成jpg后再开始训练。
标签为.xml格式,文件中会有需要检测的目标信息,标签文件和输入图片文件相对应。
2、损失值的大小用于判断是否收敛,比较重要的是有收敛的趋势,即验证集损失不断下降,如果验证集损失基本上不改变的话,模型基本上就收敛了。
损失值的具体大小并没有什么意义,大和小只在于损失的计算方式,并不是接近于0才好。如果想要让损失好看点,可以直接到对应的损失函数里面除上10000。
训练过程中的损失值会保存在logs文件夹下的loss_%Y_%m_%d_%H_%M_%S文件夹中
3、训练好的权值文件保存在logs文件夹中,每个训练世代(Epoch)包含若干训练步长(Step),每个训练步长(Step)进行一次梯度下降。
如果只是训练了几个Step是不会保存的,Epoch和Step的概念要捋清楚一下。
'''
if __name__ == "__main__":
    # Training entry point: configuration, model/optimizer/data-loader setup
    # and the epoch loop (frozen-backbone phase followed by an unfrozen phase).
    # ---------------------------------#
    #   Cuda    whether to use CUDA
    #           set to False when no GPU is available
    # ---------------------------------#
    Cuda = True
    # ---------------------------------------------------------------------#
    #   distributed     whether to run single-machine multi-GPU distributed training
    #                   The terminal commands below only support Ubuntu;
    #                   CUDA_VISIBLE_DEVICES selects the GPUs there.
    #                   On Windows, DP mode uses all GPUs by default; DDP is not supported.
    #   DP mode:
    #       set             distributed = False
    #       run             CUDA_VISIBLE_DEVICES=0,1 python train.py
    #   DDP mode:
    #       set             distributed = True
    #       run             CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
    # ---------------------------------------------------------------------#
    distributed = False
    # ---------------------------------------------------------------------#
    #   sync_bn     whether to use SyncBatchNorm; available for multi-GPU DDP mode
    # ---------------------------------------------------------------------#
    sync_bn = False
    # ---------------------------------------------------------------------#
    #   fp16        whether to use mixed-precision training
    #               roughly halves GPU memory; needs pytorch 1.7.1 or later
    # ---------------------------------------------------------------------#
    fp16 = False
    # ---------------------------------------------------------------------#
    #   classes_path    txt file under model_data describing the dataset classes
    #                   must be changed to match your own dataset before training
    # ---------------------------------------------------------------------#
    classes_path = 'model_data/voc_classes.txt'
    # ----------------------------------------------------------------------------------------------------------------------------#
    #   See the README for downloading weight files. Pretrained weights are generic across
    #   datasets because the features are generic; the important part is the backbone,
    #   which performs feature extraction. Pretrained weights should be used in ~99% of cases:
    #   without them the backbone weights are too random and training results are poor.
    #
    #   To resume an interrupted run, point model_path at a weight file in the logs folder and
    #   adjust the freeze / unfreeze parameters below so the epochs stay continuous.
    #
    #   model_path = '' loads no weights at all.
    #
    #   The weights of the whole model are used here, so they are loaded in train.py.
    #   To train from scratch, set model_path = '' and Freeze_Train = False below;
    #   training then starts from zero without a frozen-backbone phase.
    #
    #   Training from scratch generally performs badly and is strongly discouraged.
    #   Two options if you still want to do it:
    #   1. Relying on Mosaic augmentation, with a large UnFreeze_Epoch (300+), a large batch (16+)
    #      and a lot of data (10k+), set mosaic=True and train from random initialisation; results are
    #      still worse than with pretraining. (Feasible for large datasets such as COCO.)
    #   2. Train a classification model on ImageNet first to obtain backbone weights, which are
    #      shared with this model, and train from those.
    # ----------------------------------------------------------------------------------------------------------------------------#
    model_path = 'model_data/yolox_s.pth'
    # ------------------------------------------------------#
    #   input_shape     input size, must be a multiple of 32
    # ------------------------------------------------------#
    input_shape = [640, 640]
    # ------------------------------------------------------#
    #   YoloX variant to use: nano, tiny, s, m, l, x
    # ------------------------------------------------------#
    phi = 's'
    # ------------------------------------------------------------------#
    #   mosaic              Mosaic data augmentation.
    #   mosaic_prob         probability of using mosaic at each step, 50% by default.
    #
    #   mixup               whether to use mixup augmentation; only effective when mosaic=True,
    #                       and only applied to mosaic-augmented images.
    #   mixup_prob          probability of applying mixup after mosaic, 50% by default.
    #                       The overall mixup probability is mosaic_prob * mixup_prob.
    #
    #   special_aug_ratio   Following YoloX: mosaic images are far from the natural image distribution,
    #                       so when mosaic=True it is only enabled within the first special_aug_ratio
    #                       share of the epochs (default: first 70%, e.g. 70 of 100 epochs).
    #
    #   Cosine-annealing parameters are set via lr_decay_type below.
    # ------------------------------------------------------------------#
    mosaic = True
    mosaic_prob = 0.5
    mixup = True
    mixup_prob = 0.5
    special_aug_ratio = 0.7

    # ----------------------------------------------------------------------------------------------------------------------------#
    #   Training has two phases: frozen and unfrozen. The frozen phase exists for machines with limited resources.
    #   Frozen training needs less GPU memory; on very weak GPUs set Freeze_Epoch equal to UnFreeze_Epoch
    #   with Freeze_Train = True to run frozen training only.
    #
    #   Suggested settings (adjust to your needs):
    #   (1) Starting from pretrained weights of the whole model:
    #       Adam:
    #           Init_Epoch = 0, Freeze_Epoch = 50, UnFreeze_Epoch = 100, Freeze_Train = True, optimizer_type = 'adam', Init_lr = 1e-3, weight_decay = 0. (frozen)
    #           Init_Epoch = 0, UnFreeze_Epoch = 100, Freeze_Train = False, optimizer_type = 'adam', Init_lr = 1e-3, weight_decay = 0. (not frozen)
    #       SGD:
    #           Init_Epoch = 0, Freeze_Epoch = 50, UnFreeze_Epoch = 300, Freeze_Train = True, optimizer_type = 'sgd', Init_lr = 1e-2, weight_decay = 5e-4. (frozen)
    #           Init_Epoch = 0, UnFreeze_Epoch = 300, Freeze_Train = False, optimizer_type = 'sgd', Init_lr = 1e-2, weight_decay = 5e-4. (not frozen)
    #       UnFreeze_Epoch can be tuned between 100 and 300.
    #   (2) Training from scratch:
    #       Init_Epoch = 0, UnFreeze_Epoch >= 300, Unfreeze_batch_size >= 16, Freeze_Train = False (no frozen phase)
    #       UnFreeze_Epoch should not be below 300. optimizer_type = 'sgd', Init_lr = 1e-2, mosaic = True.
    #   (3) Choosing batch_size:
    #       As large as the GPU allows. Out-of-memory errors (OOM / CUDA out of memory) are unrelated to
    #       dataset size; reduce batch_size.
    #       Because of BatchNorm, batch_size must be at least 2, never 1.
    #       Freeze_batch_size is normally 1-2x Unfreeze_batch_size; a larger gap is discouraged because it
    #       affects the automatic learning-rate adjustment.
    # ----------------------------------------------------------------------------------------------------------------------------#
    # ------------------------------------------------------------------#
    #   Frozen-phase parameters
    #   The backbone is frozen, so the feature extractor does not change;
    #   less GPU memory is used and the network is only fine-tuned.
    #   Init_Epoch          epoch training starts from; may exceed Freeze_Epoch, e.g.
    #                       Init_Epoch = 60, Freeze_Epoch = 50, UnFreeze_Epoch = 100
    #                       skips the frozen phase, starts at epoch 60 and adjusts the learning rate.
    #                       (used when resuming from a checkpoint)
    #   Freeze_Epoch        number of frozen-training epochs
    #                       (ignored when Freeze_Train=False)
    #   Freeze_batch_size   batch_size during frozen training
    #                       (ignored when Freeze_Train=False)
    # ------------------------------------------------------------------#
    Init_Epoch = 0
    Freeze_Epoch = 50
    Freeze_batch_size = 16
    # ------------------------------------------------------------------#
    #   Unfrozen-phase parameters
    #   The backbone is no longer frozen, so the feature extractor changes;
    #   more GPU memory is used and all parameters are updated.
    #   UnFreeze_Epoch          total number of training epochs
    #                           SGD needs longer to converge, so use a larger UnFreeze_Epoch;
    #                           Adam can use a smaller one.
    #   Unfreeze_batch_size     batch_size after unfreezing
    # ------------------------------------------------------------------#
    UnFreeze_Epoch = 300
    Unfreeze_batch_size = 8
    # ------------------------------------------------------------------#
    #   Freeze_Train    whether to run frozen training
    #                   by default: train with a frozen backbone first, then unfreeze.
    # ------------------------------------------------------------------#
    Freeze_Train = True

    # ------------------------------------------------------------------#
    #   Other training parameters: learning rate, optimizer, lr decay
    # ------------------------------------------------------------------#
    # ------------------------------------------------------------------#
    #   Init_lr         maximum learning rate
    #   Min_lr          minimum learning rate, by default 0.01 of the maximum
    # ------------------------------------------------------------------#
    Init_lr = 1e-2
    Min_lr = Init_lr * 0.01
    # ------------------------------------------------------------------#
    #   optimizer_type  optimizer to use: adam or sgd
    #                   Init_lr=1e-3 is suggested for Adam
    #                   Init_lr=1e-2 is suggested for SGD
    #   momentum        momentum parameter used inside the optimizer
    #   weight_decay    weight decay, helps against overfitting
    #                   it misbehaves with adam; 0 is suggested when using adam.
    # ------------------------------------------------------------------#
    optimizer_type = "sgd"
    momentum = 0.937
    weight_decay = 5e-4
    # ------------------------------------------------------------------#
    #   lr_decay_type   learning-rate decay scheme: step or cos
    # ------------------------------------------------------------------#
    lr_decay_type = "cos"
    # ------------------------------------------------------------------#
    #   save_period     save the weights every save_period epochs
    # ------------------------------------------------------------------#
    save_period = 10
    # ------------------------------------------------------------------#
    #   save_dir        folder for weights and log files
    # ------------------------------------------------------------------#
    save_dir = 'logs'
    # ------------------------------------------------------------------#
    #   eval_flag       whether to evaluate on the validation set during training
    #                   (works better with pycocotools installed)
    #   eval_period     evaluate every eval_period epochs; frequent evaluation is
    #                   discouraged because it is slow and makes training very slow.
    #   The mAP obtained here differs from the one from get_map.py because:
    #   (1) it is computed on the validation set;
    #   (2) the evaluation parameters are conservative to speed up evaluation.
    # ------------------------------------------------------------------#
    eval_flag = True
    eval_period = 10
    # ------------------------------------------------------------------#
    #   num_workers     number of worker processes used to load data
    #                   speeds up data loading but uses more memory;
    #                   use 2 or 0 on machines with little memory
    # ------------------------------------------------------------------#
    num_workers = 4

    # ----------------------------------------------------#
    #   Image paths and labels
    # ----------------------------------------------------#
    train_annotation_path = '2007_train.txt'
    val_annotation_path = '2007_val.txt'

    # ------------------------------------------------------#
    #   Select the GPUs in use
    # ------------------------------------------------------#
    ngpus_per_node = torch.cuda.device_count()
    if distributed:
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        device = torch.device("cuda", local_rank)
        if local_rank == 0:
            print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...")
            print("Gpu Device Count : ", ngpus_per_node)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        local_rank = 0
        rank = 0

    # ----------------------------------------------------#
    #   Load the class names
    # ----------------------------------------------------#
    class_names, num_classes = get_classes(classes_path)

    # ------------------------------------------------------#
    #   Build the yolo model
    # ------------------------------------------------------#
    model = YoloBody(num_classes, phi)
    weights_init(model)
    if model_path != '':
        # ------------------------------------------------------#
        #   See the README for the weight files
        # ------------------------------------------------------#
        if local_rank == 0:
            print('Load weights {}.'.format(model_path))

        # ------------------------------------------------------#
        #   Load by matching the keys (and shapes) of the pretrained
        #   weights against the keys of the model
        # ------------------------------------------------------#
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_path, map_location=device)
        load_key, no_load_key, temp_dict = [], [], {}
        for k, v in pretrained_dict.items():
            if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
                temp_dict[k] = v
                load_key.append(k)
            else:
                no_load_key.append(k)
        model_dict.update(temp_dict)
        model.load_state_dict(model_dict)
        # ------------------------------------------------------#
        #   Report the keys that could not be matched
        #   (the printed hint says: head keys failing to load is normal,
        #   backbone keys failing to load is an error)
        # ------------------------------------------------------#
        if local_rank == 0:
            print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key))
            print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key))
            print("\n\033[1;33;44m温馨提示,head部分没有载入是正常现象,Backbone部分没有载入是错误的。\033[0m")

    # ----------------------#
    #   Loss function
    # ----------------------#
    yolo_loss = YOLOLoss(num_classes, fp16)
    # ----------------------#
    #   Loss logging (rank 0 only)
    # ----------------------#
    if local_rank == 0:
        time_str = datetime.datetime.strftime(datetime.datetime.now(), '%Y_%m_%d_%H_%M_%S')
        log_dir = os.path.join(save_dir, "loss_" + str(time_str))
        loss_history = LossHistory(log_dir, model, input_shape=input_shape)
    else:
        loss_history = None

    # ------------------------------------------------------------------#
    #   torch 1.2 does not support amp; use torch 1.7.1+ for fp16.
    #   With torch 1.2 this import is reported as "could not be resolve".
    # ------------------------------------------------------------------#
    if fp16:
        from torch.cuda.amp import GradScaler as GradScaler
        scaler = GradScaler()
    else:
        scaler = None

    model_train = model.train()
    # ----------------------------#
    #   Multi-GPU synchronised BN
    # ----------------------------#
    if sync_bn and ngpus_per_node > 1 and distributed:
        model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train)
    elif sync_bn:
        print("Sync_bn is not support in one gpu or not distributed.")

    if Cuda:
        if distributed:
            # ----------------------------#
            #   Multi-GPU parallel run (DDP)
            # ----------------------------#
            model_train = model_train.cuda(local_rank)
            model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank],
                                                                    find_unused_parameters=True)
        else:
            model_train = torch.nn.DataParallel(model)
            cudnn.benchmark = True
            model_train = model_train.cuda()

    # ----------------------------#
    #   Weight smoothing (EMA)
    # ----------------------------#
    ema = ModelEMA(model_train)

    # ---------------------------#
    #   Read the dataset txt files
    # ---------------------------#
    with open(train_annotation_path, encoding='utf-8') as f:
        train_lines = f.readlines()
    with open(val_annotation_path, encoding='utf-8') as f:
        val_lines = f.readlines()
    num_train = len(train_lines)
    num_val = len(val_lines)

    # NOTE(review): the nesting of the step-count check below under this rank-0 branch is
    # reconstructed (the indentation of this listing was lost) — confirm against the original file.
    if local_rank == 0:
        show_config(
            classes_path=classes_path, model_path=model_path, input_shape=input_shape, \
            Init_Epoch=Init_Epoch, Freeze_Epoch=Freeze_Epoch, UnFreeze_Epoch=UnFreeze_Epoch,
            Freeze_batch_size=Freeze_batch_size, Unfreeze_batch_size=Unfreeze_batch_size, Freeze_Train=Freeze_Train, \
            Init_lr=Init_lr, Min_lr=Min_lr, optimizer_type=optimizer_type, momentum=momentum,
            lr_decay_type=lr_decay_type, \
            save_period=save_period, save_dir=save_dir, num_workers=num_workers, num_train=num_train, num_val=num_val
        )
        # ---------------------------------------------------------#
        #   Total epochs = number of full passes over the data.
        #   Total steps  = number of gradient-descent updates.
        #   Each epoch contains several steps; each step is one update.
        #   Only a minimum is recommended here (no upper bound), and the
        #   computation only considers the unfrozen phase.
        #   The warnings printed below report the recommended total step
        #   count and the epoch count needed to reach it.
        # ----------------------------------------------------------#
        wanted_step = 5e4 if optimizer_type == "sgd" else 1.5e4
        total_step = num_train // Unfreeze_batch_size * UnFreeze_Epoch
        if total_step <= wanted_step:
            if num_train // Unfreeze_batch_size == 0:
                # Message: dataset too small to train, please enlarge it.
                raise ValueError('数据集过小,无法进行训练,请扩充数据集。')
            wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1
            print("\n\033[1;33;44m[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m" % (
                optimizer_type, wanted_step))
            print(
                "\033[1;33;44m[Warning] 本次运行的总训练数据量为%dUnfreeze_batch_size为%d,共训练%d个Epoch,计算出总训练步长为%d\033[0m" % (
                    num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step))
            print("\033[1;33;44m[Warning] 由于总训练步长为%d,小于建议总步长%d,建议设置总世代为%d\033[0m" % (
                total_step, wanted_step, wanted_epoch))

    # ------------------------------------------------------#
    #   Backbone features are generic; frozen training speeds up
    #   training and protects the weights early on.
    #   Init_Epoch      starting epoch
    #   Freeze_Epoch    number of frozen-training epochs
    #   UnFreeze_Epoch  total number of epochs
    #   On OOM / insufficient GPU memory, reduce the batch size.
    # ------------------------------------------------------#
    if True:
        UnFreeze_flag = False
        # ------------------------------------#
        #   Freeze the backbone
        # ------------------------------------#
        if Freeze_Train:
            for param in model.backbone.parameters():
                param.requires_grad = False

        # -------------------------------------------------------------------#
        #   Without frozen training, use Unfreeze_batch_size right away
        # -------------------------------------------------------------------#
        batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size

        # -------------------------------------------------------------------#
        #   Scale the learning rate to the current batch_size
        #   (linear in batch_size / nbs, clamped to per-optimizer limits)
        # -------------------------------------------------------------------#
        nbs = 64
        lr_limit_max = 1e-3 if optimizer_type == 'adam' else 5e-2
        lr_limit_min = 3e-4 if optimizer_type == 'adam' else 5e-4
        Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
        Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)

        # ---------------------------------------#
        #   Choose the optimizer by optimizer_type
        #   pg0: BatchNorm weights (no weight decay)
        #   pg1: other weights (with weight decay)
        #   pg2: biases (no weight decay)
        # ---------------------------------------#
        pg0, pg1, pg2 = [], [], []
        for k, v in model.named_modules():
            if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
                pg2.append(v.bias)
            if isinstance(v, nn.BatchNorm2d) or "bn" in k:
                pg0.append(v.weight)
            elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
                pg1.append(v.weight)
        # Both optimizers are constructed; the one named by optimizer_type is kept.
        optimizer = {
            'adam': optim.Adam(pg0, Init_lr_fit, betas=(momentum, 0.999)),
            'sgd': optim.SGD(pg0, Init_lr_fit, momentum=momentum, nesterov=True)
        }[optimizer_type]
        optimizer.add_param_group({"params": pg1, "weight_decay": weight_decay})
        optimizer.add_param_group({"params": pg2})

        # ---------------------------------------#
        #   Learning-rate schedule function
        # ---------------------------------------#
        lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)

        # ---------------------------------------#
        #   Number of steps per epoch
        # ---------------------------------------#
        epoch_step = num_train // batch_size
        epoch_step_val = num_val // batch_size

        if epoch_step == 0 or epoch_step_val == 0:
            # Message: dataset too small to continue training, please enlarge it.
            raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")

        # Align the EMA update counter with the starting epoch.
        if ema:
            ema.updates = epoch_step * Init_Epoch

        # ---------------------------------------#
        #   Build the datasets and data loaders
        # ---------------------------------------#
        train_dataset = YoloDataset(train_lines, input_shape, num_classes, epoch_length=UnFreeze_Epoch, \
                                    mosaic=mosaic, mixup=mixup, mosaic_prob=mosaic_prob, mixup_prob=mixup_prob,
                                    train=True, special_aug_ratio=special_aug_ratio)
        val_dataset = YoloDataset(val_lines, input_shape, num_classes, epoch_length=UnFreeze_Epoch, \
                                  mosaic=False, mixup=False, mosaic_prob=0, mixup_prob=0, train=False,
                                  special_aug_ratio=0)

        if distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, )
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, )
            # Per-process batch size; shuffling is delegated to the samplers.
            batch_size = batch_size // ngpus_per_node
            shuffle = False
        else:
            train_sampler = None
            val_sampler = None
            shuffle = True

        gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
                         pin_memory=True,
                         drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler)
        gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
                             pin_memory=True,
                             drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler)

        # ----------------------#
        #   Record the eval mAP curve (rank 0 only)
        # ----------------------#
        if local_rank == 0:
            eval_callback = EvalCallback(model, input_shape, class_names, num_classes, val_lines, log_dir, Cuda, \
                                         eval_flag=eval_flag, period=eval_period)
        else:
            eval_callback = None

        # ---------------------------------------#
        #   Start training
        # ---------------------------------------#
        for epoch in range(Init_Epoch, UnFreeze_Epoch):
            # ---------------------------------------#
            #   If the model still has its frozen part,
            #   unfreeze it and reconfigure (runs once)
            # ---------------------------------------#
            if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train:
                batch_size = Unfreeze_batch_size

                # -------------------------------------------------------------------#
                #   Scale the learning rate to the new batch_size
                # -------------------------------------------------------------------#
                nbs = 64
                lr_limit_max = 1e-3 if optimizer_type == 'adam' else 5e-2
                lr_limit_min = 3e-4 if optimizer_type == 'adam' else 5e-4
                Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
                Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
                # ---------------------------------------#
                #   Learning-rate schedule function
                # ---------------------------------------#
                lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)

                for param in model.backbone.parameters():
                    param.requires_grad = True

                epoch_step = num_train // batch_size
                epoch_step_val = num_val // batch_size

                if epoch_step == 0 or epoch_step_val == 0:
                    # Message: dataset too small to continue training, please enlarge it.
                    raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")

                if distributed:
                    batch_size = batch_size // ngpus_per_node

                # Re-align the EMA update counter with the new steps-per-epoch.
                if ema:
                    ema.updates = epoch_step * epoch

                # Rebuild the loaders with the unfrozen batch size.
                gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
                                 pin_memory=True,
                                 drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler)
                gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
                                     pin_memory=True,
                                     drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler)

                UnFreeze_flag = True

            # Tell the datasets which epoch is running.
            gen.dataset.epoch_now = epoch
            gen_val.dataset.epoch_now = epoch

            if distributed:
                train_sampler.set_epoch(epoch)

            set_optimizer_lr(optimizer, lr_scheduler_func, epoch)

            fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step,
                          epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir,
                          local_rank)

            if distributed:
                dist.barrier()

        if local_rank == 0:
            loss_history.writer.close()
Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 67 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 126 KiB

+1
View File
@@ -0,0 +1 @@
http://localhost/img/000009.jpg
Binary file not shown.

After

Width:  |  Height:  |  Size: 174 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 145 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 146 KiB

+1
View File
@@ -0,0 +1 @@
#
+227
View File
@@ -0,0 +1,227 @@
import os
import torch
import matplotlib
matplotlib.use('Agg')
import scipy.signal
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm
from .utils import cvtColor, preprocess_input, resize_image
from .utils_bbox import decode_outputs, non_max_suppression
from .utils_map import get_coco_map, get_map
class LossHistory():
    """Record per-epoch training/validation loss.

    Every call to :meth:`append_loss` appends the values to
    ``epoch_loss.txt`` / ``epoch_val_loss.txt`` inside ``log_dir``, logs them
    to the TensorBoard ``SummaryWriter`` and regenerates ``epoch_loss.png``.

    Args:
        log_dir: Directory that receives the text logs, the TensorBoard event
            files and the loss plot. Created if it does not exist.
        model: Model handed to ``SummaryWriter.add_graph`` (best effort).
        input_shape: Indexable whose first two items are used as the last two
            dimensions of the dummy input built for ``add_graph``.
    """

    def __init__(self, log_dir, model, input_shape):
        self.log_dir = log_dir
        self.losses = []
        self.val_loss = []

        # exist_ok=True: re-using an existing log directory must not abort
        # training with FileExistsError.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        try:
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except Exception:
            # Exporting the graph is optional; a failure here must not stop
            # training. Only Exception is caught so Ctrl-C still works.
            pass

    def append_loss(self, epoch, loss, val_loss):
        """Store one epoch's losses, persist them and refresh the plot.

        Args:
            epoch: Step index used for the TensorBoard scalars.
            loss: Training loss of the epoch.
            val_loss: Validation loss of the epoch.
        """
        # The directory may have been removed since construction.
        os.makedirs(self.log_dir, exist_ok=True)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss))
            f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    def loss_plot(self):
        """Render the recorded losses to ``epoch_loss.png`` in ``log_dir``."""
        iters = range(len(self.losses))

        plt.figure()
        try:
            plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
            plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
            try:
                # Smoothing window: 5 points for short histories, 15 otherwise.
                if len(self.losses) < 25:
                    num = 5
                else:
                    num = 15

                plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
                plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
            except Exception:
                # The smoothed curves are optional (e.g. the filter rejects a
                # history shorter than its window); keep the raw curves.
                pass

            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend(loc="upper right")

            plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
        finally:
            # Always release the figure, even if plotting or saving failed,
            # so figures do not accumulate across epochs.
            plt.cla()
            plt.close("all")
class EvalCallback():
    """Periodically evaluate mAP on the validation annotations during training.

    On every ``period``-th epoch (when ``eval_flag`` is true) the callback
    writes per-image detection and ground-truth text files below
    ``map_out_path``, computes the mAP from them, appends the value to
    ``epoch_map.txt`` and redraws ``epoch_map.png`` in ``log_dir``, then
    deletes ``map_out_path``.

    Args:
        net: Network used for inference; replaced by ``model_eval`` on each
            evaluation.
        input_shape: Indexable; items 0 and 1 are passed to ``resize_image``
            as ``(input_shape[1], input_shape[0])``.
        class_names: Sequence of class names indexed by class id.
        num_classes: Number of classes, forwarded to ``non_max_suppression``.
        val_lines: Annotation lines, ``path`` followed by whitespace-separated
            ``left,top,right,bottom,class_id`` integer groups.
        log_dir: Existing directory for ``epoch_map.txt`` / ``epoch_map.png``.
        cuda: Move the input tensor to the GPU before inference when true.
        map_out_path: Scratch directory for the intermediate text files.
        max_boxes: Maximum number of detections written per image.
        confidence: Confidence threshold forwarded to ``non_max_suppression``.
        nms_iou: IoU threshold forwarded to ``non_max_suppression``.
        letterbox_image: Forwarded to ``resize_image`` and
            ``non_max_suppression``.
        MINOVERLAP: Overlap threshold forwarded to ``get_map`` and shown on
            the plot's y-axis label.
        eval_flag: Master switch; nothing is evaluated when false.
        period: Evaluate on epochs where ``epoch % period == 0``.
    """

    def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, \
            map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
        super(EvalCallback, self).__init__()

        self.net                = net
        self.input_shape        = input_shape
        self.class_names        = class_names
        self.num_classes        = num_classes
        self.val_lines          = val_lines
        self.log_dir            = log_dir
        self.cuda               = cuda
        self.map_out_path       = map_out_path
        self.max_boxes          = max_boxes
        self.confidence         = confidence
        self.nms_iou            = nms_iou
        self.letterbox_image    = letterbox_image
        self.MINOVERLAP         = MINOVERLAP
        self.eval_flag          = eval_flag
        self.period             = period

        # The curve starts from a (epoch 0, mAP 0) point.
        self.maps       = [0]
        self.epoches    = [0]
        if self.eval_flag:
            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(0))
                f.write("\n")

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        """Run detection on one image and write its ``detection-results`` file.

        The file ``<map_out_path>/detection-results/<image_id>.txt`` is always
        created; it stays empty when the network yields no detections.

        Args:
            image_id: File stem used for the output text file.
            image: Image object accepted by ``cvtColor`` / ``resize_image``.
            class_names: Only detections whose class name is in this
                collection are written.
            map_out_path: Root directory of the evaluation text files.
        """
        result_path = os.path.join(map_out_path, "detection-results/"+image_id+".txt")
        # The context manager guarantees the handle is closed on every path,
        # including the early return when nothing was detected.
        with open(result_path, "w") as f:
            image_shape = np.array(np.shape(image)[0:2])
            #---------------------------------------------------------#
            #   Convert to RGB so that non-RGB inputs (e.g. greyscale)
            #   do not break inference.
            #---------------------------------------------------------#
            image       = cvtColor(image)
            #---------------------------------------------------------#
            #   Resize to the network input size; whether grey bars are
            #   added is controlled by letterbox_image.
            #---------------------------------------------------------#
            image_data  = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
            #---------------------------------------------------------#
            #   HWC -> CHW and add the batch dimension.
            #---------------------------------------------------------#
            image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

            with torch.no_grad():
                images = torch.from_numpy(image_data)
                if self.cuda:
                    images = images.cuda()
                #---------------------------------------------------------#
                #   Forward pass and decoding of the raw outputs.
                #---------------------------------------------------------#
                outputs = self.net(images)
                outputs = decode_outputs(outputs, self.input_shape)
                #---------------------------------------------------------#
                #   Non-maximum suppression over the decoded boxes.
                #---------------------------------------------------------#
                results = non_max_suppression(outputs, self.num_classes, self.input_shape,
                            image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)

                if results[0] is None:
                    return

                # Columns as used here: 0-3 box, 4*5 score, 6 class id.
                top_label   = np.array(results[0][:, 6], dtype = 'int32')
                top_conf    = results[0][:, 4] * results[0][:, 5]
                top_boxes   = results[0][:, :4]

            # Keep only the max_boxes highest-scoring detections.
            top_100     = np.argsort(top_conf)[::-1][:self.max_boxes]
            top_boxes   = top_boxes[top_100]
            top_conf    = top_conf[top_100]
            top_label   = top_label[top_100]

            for i, c in enumerate(top_label):
                predicted_class = self.class_names[int(c)]
                box             = top_boxes[i]
                score           = str(top_conf[i])

                # Boxes are unpacked as (top, left, bottom, right) and written
                # as left top right bottom.
                top, left, bottom, right = box
                if predicted_class not in class_names:
                    continue

                f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
        return

    def on_epoch_end(self, epoch, model_eval):
        """Evaluate mAP for ``model_eval`` if ``epoch`` is an evaluation epoch.

        Args:
            epoch: Current epoch number.
            model_eval: Network to evaluate; stored as ``self.net``.
        """
        if not (epoch % self.period == 0 and self.eval_flag):
            return

        self.net = model_eval
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(self.map_out_path, exist_ok=True)
        for sub_dir in ("ground-truth", "detection-results"):
            os.makedirs(os.path.join(self.map_out_path, sub_dir), exist_ok=True)

        print("Get map.")
        for annotation_line in tqdm(self.val_lines):
            line        = annotation_line.split()
            image_id    = os.path.basename(line[0]).split('.')[0]
            #------------------------------#
            #   Open the image.
            #------------------------------#
            image       = Image.open(line[0])
            #------------------------------#
            #   Parse the ground-truth boxes.
            #------------------------------#
            gt_boxes    = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
            #------------------------------#
            #   Write the detection txt.
            #------------------------------#
            self.get_map_txt(image_id, image, self.class_names, self.map_out_path)

            #------------------------------#
            #   Write the ground-truth txt.
            #------------------------------#
            with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
                for box in gt_boxes:
                    left, top, right, bottom, obj = box
                    obj_name = self.class_names[obj]
                    new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))

        print("Calculate Map.")
        try:
            temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
        except Exception:
            # Fall back to the alternative mAP routine when the COCO-style
            # evaluation fails. Exception (not bare except) keeps
            # KeyboardInterrupt / SystemExit working.
            temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
        self.maps.append(temp_map)
        self.epoches.append(epoch)

        with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
            f.write(str(temp_map))
            f.write("\n")

        plt.figure()
        plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Map %s'%str(self.MINOVERLAP))
        plt.title('A Map Curve')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
        plt.cla()
        plt.close("all")

        print("Get map done.")
        shutil.rmtree(self.map_out_path)
+374
View File
@@ -0,0 +1,374 @@
from random import sample, shuffle
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from utils.utils import cvtColor, preprocess_input
class YoloDataset(Dataset):
    """Detection dataset with optional Mosaic / MixUp augmentation.

    Each annotation line is ``path`` followed by whitespace-separated
    ``x1,y1,x2,y2,class_id`` integer groups. ``__getitem__`` returns a
    preprocessed CHW float32 image and a float32 box array whose first four
    columns are converted to ``(cx, cy, w, h)``.

    Args:
        annotation_lines: List of annotation lines.
        input_shape: ``(h, w)`` of the network input.
        num_classes: Number of classes (stored, not used by this class).
        epoch_length: Total number of epochs; with ``special_aug_ratio`` it
            bounds the epochs during which Mosaic is applied.
        mosaic: Enable Mosaic augmentation.
        mixup: Enable MixUp on top of Mosaic samples.
        mosaic_prob: Per-sample probability of using Mosaic.
        mixup_prob: Per-Mosaic-sample probability of adding MixUp.
        train: Passed as ``random`` to :meth:`get_random_data`.
        special_aug_ratio: Fraction of ``epoch_length`` during which Mosaic is
            allowed (compared against ``epoch_now``).
    """

    def __init__(self, annotation_lines, input_shape, num_classes, epoch_length, \
                        mosaic, mixup, mosaic_prob, mixup_prob, train, special_aug_ratio = 0.7):
        super(YoloDataset, self).__init__()
        self.annotation_lines   = annotation_lines
        self.input_shape        = input_shape
        self.num_classes        = num_classes
        self.epoch_length       = epoch_length
        self.mosaic             = mosaic
        self.mosaic_prob        = mosaic_prob
        self.mixup              = mixup
        self.mixup_prob         = mixup_prob
        self.train              = train
        self.special_aug_ratio  = special_aug_ratio

        # Set from outside before each epoch.
        self.epoch_now          = -1
        self.length             = len(self.annotation_lines)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        index = index % self.length

        #---------------------------------------------------#
        #   Mosaic (optionally followed by MixUp) is used only
        #   during the first special_aug_ratio part of training;
        #   otherwise a single image is loaded, randomly
        #   augmented only when self.train is true.
        #---------------------------------------------------#
        if self.mosaic and self.rand() < self.mosaic_prob and self.epoch_now < self.epoch_length * self.special_aug_ratio:
            lines = sample(self.annotation_lines, 3)
            lines.append(self.annotation_lines[index])
            shuffle(lines)
            image, box  = self.get_random_data_with_Mosaic(lines, self.input_shape)

            if self.mixup and self.rand() < self.mixup_prob:
                lines           = sample(self.annotation_lines, 1)
                image_2, box_2  = self.get_random_data(lines[0], self.input_shape, random = self.train)
                image, box      = self.get_random_data_with_MixUp(image, box, image_2, box_2)
        else:
            image, box      = self.get_random_data(self.annotation_lines[index], self.input_shape, random = self.train)

        image       = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
        box         = np.array(box, dtype=np.float32)
        if len(box) != 0:
            # (x1, y1, x2, y2) -> (cx, cy, w, h)
            box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
            box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
        return image, box

    def rand(self, a=0, b=1):
        """Return a random float drawn uniformly from ``[a, b)``."""
        return np.random.rand()*(b-a) + a

    @staticmethod
    def _parse_boxes(fields):
        """Parse ``"x1,y1,x2,y2,cls"`` strings into an integer array."""
        return np.array([np.array(list(map(int,box.split(',')))) for box in fields])

    @staticmethod
    def _fit_boxes(box, nw, nh, iw, ih, dx, dy, w, h, flip=False):
        """Map boxes from the source image onto the ``w`` x ``h`` canvas.

        Shuffles the boxes, rescales them from ``(iw, ih)`` to ``(nw, nh)``,
        shifts them by ``(dx, dy)``, optionally mirrors them horizontally,
        clips them to the canvas and drops boxes that are not more than one
        pixel wide and high. ``box`` is modified in place; an empty ``box`` is
        returned untouched.
        """
        if len(box)>0:
            np.random.shuffle(box)
            box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
            box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
            if flip: box[:, [0,2]] = w - box[:, [2,0]]
            box[:, 0:2][box[:, 0:2]<0] = 0
            box[:, 2][box[:, 2]>w] = w
            box[:, 3][box[:, 3]>h] = h
            box_w = box[:, 2] - box[:, 0]
            box_h = box[:, 3] - box[:, 1]
            box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box
        return box

    @staticmethod
    def _hsv_jitter(image_data, hue, sat, val):
        """Randomly scale hue / saturation / value of an RGB uint8 image.

        One gain per channel is drawn uniformly from ``1 +/- hue|sat|val`` and
        applied through lookup tables in HSV space.
        """
        r               = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
        hue_ch, sat_ch, val_ch = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
        dtype           = image_data.dtype

        x       = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        jittered = cv2.merge((cv2.LUT(hue_ch, lut_hue), cv2.LUT(sat_ch, lut_sat), cv2.LUT(val_ch, lut_val)))
        return cv2.cvtColor(jittered, cv2.COLOR_HSV2RGB)

    def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
        """Load one annotated image, with or without random augmentation.

        Args:
            annotation_line: One annotation line.
            input_shape: Target ``(h, w)``.
            jitter: Strength of the random aspect-ratio distortion.
            hue, sat, val: Strengths of the random HSV gains.
            random: When false, only letterbox the image onto a grey canvas.

        Returns:
            ``(image_data, box)``: the ``h`` x ``w`` x 3 image array and the
            boxes mapped to it.
        """
        line    = annotation_line.split()
        #------------------------------#
        #   Read the image as RGB.
        #------------------------------#
        image   = Image.open(line[0])
        image   = cvtColor(image)
        #------------------------------#
        #   Source size and target size.
        #------------------------------#
        iw, ih  = image.size
        h, w    = input_shape
        #------------------------------#
        #   Ground-truth boxes.
        #------------------------------#
        box     = self._parse_boxes(line[1:])

        if not random:
            scale = min(w/iw, h/ih)
            nw = int(iw*scale)
            nh = int(ih*scale)
            dx = (w-nw)//2
            dy = (h-nh)//2

            #---------------------------------#
            #   Centre the resized image on a grey canvas.
            #---------------------------------#
            image       = image.resize((nw,nh), Image.BICUBIC)
            new_image   = Image.new('RGB', (w,h), (128,128,128))
            new_image.paste(image, (dx, dy))
            image_data  = np.array(new_image, np.float32)

            box = self._fit_boxes(box, nw, nh, iw, ih, dx, dy, w, h)
            return image_data, box

        #------------------------------------------#
        #   Random scale and aspect-ratio distortion.
        #------------------------------------------#
        new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
        scale = self.rand(.25, 2)
        if new_ar < 1:
            nh = int(scale*h)
            nw = int(nh*new_ar)
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        image = image.resize((nw,nh), Image.BICUBIC)

        #------------------------------------------#
        #   Paste at a random offset on a grey canvas.
        #------------------------------------------#
        dx = int(self.rand(0, w-nw))
        dy = int(self.rand(0, h-nh))
        new_image = Image.new('RGB', (w,h), (128,128,128))
        new_image.paste(image, (dx, dy))
        image = new_image

        #------------------------------------------#
        #   Random horizontal flip.
        #------------------------------------------#
        flip = self.rand()<.5
        if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)

        #---------------------------------#
        #   Colour jitter, then box adjustment. The order matters:
        #   it fixes the sequence of random draws.
        #---------------------------------#
        image_data = self._hsv_jitter(np.array(image, np.uint8), hue, sat, val)
        box = self._fit_boxes(box, nw, nh, iw, ih, dx, dy, w, h, flip=flip)
        return image_data, box

    def merge_bboxes(self, bboxes, cutx, cuty):
        """Clip the boxes of the four Mosaic tiles at the cut lines.

        Tile 0 is top-left, 1 bottom-left, 2 bottom-right, 3 top-right. A box
        lying entirely outside its tile's quadrant is dropped; a box crossing
        a cut line is truncated at that line.

        Args:
            bboxes: Four box collections, one per tile, rows
                ``x1, y1, x2, y2, ..., class``.
            cutx: x coordinate of the vertical cut.
            cuty: y coordinate of the horizontal cut.

        Returns:
            List of ``[x1, y1, x2, y2, class]`` lists.
        """
        merge_bbox = []
        for i, tile_boxes in enumerate(bboxes):
            for box in tile_boxes:
                x1, y1, x2, y2 = box[0], box[1], box[2], box[3]

                if i == 0:
                    if y1 > cuty or x1 > cutx:
                        continue
                    if y2 >= cuty and y1 <= cuty:
                        y2 = cuty
                    if x2 >= cutx and x1 <= cutx:
                        x2 = cutx

                if i == 1:
                    if y2 < cuty or x1 > cutx:
                        continue
                    if y2 >= cuty and y1 <= cuty:
                        y1 = cuty
                    if x2 >= cutx and x1 <= cutx:
                        x2 = cutx

                if i == 2:
                    if y2 < cuty or x2 < cutx:
                        continue
                    if y2 >= cuty and y1 <= cuty:
                        y1 = cuty
                    if x2 >= cutx and x1 <= cutx:
                        x1 = cutx

                if i == 3:
                    if y1 > cuty or x2 < cutx:
                        continue
                    if y2 >= cuty and y1 <= cuty:
                        y2 = cuty
                    if x2 >= cutx and x1 <= cutx:
                        x1 = cutx

                merge_bbox.append([x1, y1, x2, y2, box[-1]])
        return merge_bbox

    def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4):
        """Build one Mosaic sample from four annotation lines.

        Args:
            annotation_line: Iterable of four annotation lines.
            input_shape: Target ``(h, w)``.
            jitter: Strength of the random aspect-ratio distortion.
            hue, sat, val: Strengths of the random HSV gains.

        Returns:
            ``(new_image, new_boxes)``: the composed uint8 RGB image and the
            merged box list from :meth:`merge_bboxes`.
        """
        h, w = input_shape
        # Relative position of the cut point shared by all four tiles.
        min_offset_x = self.rand(0.3, 0.7)
        min_offset_y = self.rand(0.3, 0.7)

        image_datas = []
        box_datas   = []
        for index, line in enumerate(annotation_line):
            #---------------------------------#
            #   Split the line and open the image.
            #---------------------------------#
            line_content = line.split()
            image = Image.open(line_content[0])
            image = cvtColor(image)

            #---------------------------------#
            #   Source size and boxes.
            #---------------------------------#
            iw, ih = image.size
            box = self._parse_boxes(line_content[1:])

            #---------------------------------#
            #   Random horizontal flip (applied only when the
            #   image has boxes).
            #---------------------------------#
            flip = self.rand()<.5
            if flip and len(box)>0:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                box[:, [0,2]] = iw - box[:, [2,0]]

            #------------------------------------------#
            #   Random scale and aspect-ratio distortion.
            #------------------------------------------#
            new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
            scale = self.rand(.4, 1)
            if new_ar < 1:
                nh = int(scale*h)
                nw = int(nh*new_ar)
            else:
                nw = int(scale*w)
                nh = int(nw/new_ar)
            image = image.resize((nw, nh), Image.BICUBIC)

            #-----------------------------------------------#
            #   Anchor each tile at the cut point: 0 top-left,
            #   1 bottom-left, 2 bottom-right, 3 top-right.
            #-----------------------------------------------#
            if index == 0:
                dx = int(w*min_offset_x) - nw
                dy = int(h*min_offset_y) - nh
            elif index == 1:
                dx = int(w*min_offset_x) - nw
                dy = int(h*min_offset_y)
            elif index == 2:
                dx = int(w*min_offset_x)
                dy = int(h*min_offset_y)
            elif index == 3:
                dx = int(w*min_offset_x)
                dy = int(h*min_offset_y) - nh

            new_image = Image.new('RGB', (w,h), (128,128,128))
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image)

            #---------------------------------#
            #   Map the boxes onto the canvas.
            #---------------------------------#
            box_data = []
            if len(box)>0:
                box = self._fit_boxes(box, nw, nh, iw, ih, dx, dy, w, h)
                box_data = np.zeros((len(box),5))
                box_data[:len(box)] = box

            image_datas.append(image_data)
            box_datas.append(box_data)

        #---------------------------------#
        #   Stitch the four quadrants together.
        #---------------------------------#
        cutx = int(w * min_offset_x)
        cuty = int(h * min_offset_y)

        new_image = np.zeros([h, w, 3])
        new_image[:cuty, :cutx, :] = image_datas[0][:cuty, :cutx, :]
        new_image[cuty:, :cutx, :] = image_datas[1][cuty:, :cutx, :]
        new_image[cuty:, cutx:, :] = image_datas[2][cuty:, cutx:, :]
        new_image[:cuty, cutx:, :] = image_datas[3][:cuty, cutx:, :]

        #---------------------------------#
        #   Colour jitter on the composed image.
        #---------------------------------#
        new_image = self._hsv_jitter(np.array(new_image, np.uint8), hue, sat, val)

        #---------------------------------#
        #   Clip the boxes at the cut lines.
        #---------------------------------#
        new_boxes = self.merge_bboxes(box_datas, cutx, cuty)

        return new_image, new_boxes

    def get_random_data_with_MixUp(self, image_1, box_1, image_2, box_2):
        """Average two images and concatenate their boxes.

        Returns:
            ``(new_image, new_boxes)``: the float32 50/50 blend and the boxes
            of both inputs (only the non-empty side when one is empty).
        """
        new_image = np.array(image_1, np.float32) * 0.5 + np.array(image_2, np.float32) * 0.5
        if len(box_1) == 0:
            new_boxes = box_2
        elif len(box_2) == 0:
            new_boxes = box_1
        else:
            new_boxes = np.concatenate([box_1, box_2], axis=0)
        return new_image, new_boxes
# collate_fn for DataLoader
def yolo_dataset_collate(batch):
    """Collate ``(image, boxes)`` samples into a batch.

    Images are stacked into one float tensor; the box arrays are kept as a
    list of float tensors, one per sample.
    """
    samples = list(batch)
    stacked = np.array([img for img, _ in samples])
    images  = torch.from_numpy(stacked).type(torch.FloatTensor)
    bboxes  = [torch.from_numpy(box).type(torch.FloatTensor) for _, box in samples]
    return images, bboxes
+64
View File
@@ -0,0 +1,64 @@
import numpy as np
from PIL import Image
#---------------------------------------------------------#
#   Convert the image to RGB so that non-RGB inputs (e.g.
#   greyscale) do not break prediction. Only RGB prediction
#   is supported; every other mode is converted to RGB.
#---------------------------------------------------------#
def cvtColor(image):
    """Return ``image`` unchanged if it already has three channels, else its RGB conversion."""
    dims = np.shape(image)
    already_rgb = len(dims) == 3 and dims[2] == 3
    if not already_rgb:
        image = image.convert('RGB')
    return image
#---------------------------------------------------#
#   Resize the input image
#---------------------------------------------------#
def resize_image(image, size, letterbox_image):
    """Resize ``image`` to ``size`` (``(w, h)``).

    With ``letterbox_image`` the aspect ratio is kept and the image is centred
    on a grey ``size`` canvas; otherwise it is stretched to ``size``.
    """
    iw, ih  = image.size
    w, h    = size
    if not letterbox_image:
        return image.resize((w, h), Image.BICUBIC)

    scale   = min(w/iw, h/ih)
    nw      = int(iw*scale)
    nh      = int(ih*scale)

    resized = image.resize((nw,nh), Image.BICUBIC)
    canvas  = Image.new('RGB', size, (128,128,128))
    canvas.paste(resized, ((w-nw)//2, (h-nh)//2))
    return canvas
#---------------------------------------------------#
#   Load the class names
#---------------------------------------------------#
def get_classes(classes_path):
    """Read one class name per line from ``classes_path``.

    Returns:
        ``(class_names, count)``: the stripped lines and how many there are.
    """
    with open(classes_path, encoding='utf-8') as handle:
        class_names = [entry.strip() for entry in handle]
    return class_names, len(class_names)
def preprocess_input(image):
    """Scale ``image`` to [0, 1] and normalise it per channel, in place.

    The array passed in is modified and returned; the in-place arithmetic
    keeps its dtype. The last axis must have three channels.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std  = np.array([0.229, 0.224, 0.225])

    image /= 255.0
    image -= channel_mean
    image /= channel_std
    return image
#---------------------------------------------------#
#   Read the current learning rate
#---------------------------------------------------#
def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group (``None`` if it has none)."""
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']
def show_config(**kwargs):
    """Print the keyword arguments as a two-column ``keys`` / ``values`` table."""
    rule = '-' * 70
    row_format = '|%25s | %40s|'

    print('Configurations:')
    print(rule)
    print(row_format % ('keys', 'values'))
    print(rule)
    for name in kwargs:
        print(row_format % (str(name), str(kwargs[name])))
    print(rule)

Some files were not shown because too many files have changed in this diff Show More