init
@@ -0,0 +1,206 @@
|
||||
.DS_Store
|
||||
node_modules
|
||||
*.log
|
||||
explorations
|
||||
TODOs.md
|
||||
RELEASE_NOTE*.md
|
||||
packages/server-renderer/basic.js
|
||||
packages/server-renderer/build.dev.js
|
||||
packages/server-renderer/build.prod.js
|
||||
packages/server-renderer/server-plugin.js
|
||||
packages/server-renderer/client-plugin.js
|
||||
packages/template-compiler/build.js
|
||||
packages/template-compiler/browser.js
|
||||
.vscode
|
||||
dist
|
||||
temp
|
||||
types/v3-generated.d.ts
|
||||
|
||||
|
||||
# Compiled class file
|
||||
*.class
|
||||
|
||||
# Log file
|
||||
*.log
|
||||
|
||||
# BlueJ files
|
||||
*.ctxt
|
||||
|
||||
# Mobile Tools for Java (J2ME)
|
||||
.mtj.tmp/
|
||||
|
||||
# Package Files #
|
||||
*.jar
|
||||
*.war
|
||||
*.nar
|
||||
*.ear
|
||||
*.zip
|
||||
*.tar.gz
|
||||
*.rar
|
||||
|
||||
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
|
||||
hs_err_pid*
|
||||
replay_pid*
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
@@ -0,0 +1,33 @@
|
||||
HELP.md
|
||||
target/
|
||||
!.mvn/wrapper/maven-wrapper.jar
|
||||
!**/src/main/**/target/
|
||||
!**/src/test/**/target/
|
||||
|
||||
### STS ###
|
||||
.apt_generated
|
||||
.classpath
|
||||
.factorypath
|
||||
.project
|
||||
.settings
|
||||
.springBeans
|
||||
.sts4-cache
|
||||
|
||||
### IntelliJ IDEA ###
|
||||
.idea
|
||||
*.iws
|
||||
*.iml
|
||||
*.ipr
|
||||
|
||||
### NetBeans ###
|
||||
/nbproject/private/
|
||||
/nbbuild/
|
||||
/dist/
|
||||
/nbdist/
|
||||
/.nb-gradle/
|
||||
build/
|
||||
!**/src/main/**/build/
|
||||
!**/src/test/**/build/
|
||||
|
||||
### VS Code ###
|
||||
.vscode/
|
||||
@@ -0,0 +1,86 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-parent</artifactId>
|
||||
<version>2.6.11</version>
|
||||
<relativePath/> <!-- lookup parent from repository -->
|
||||
</parent>
|
||||
<groupId>com.common</groupId>
|
||||
<artifactId>backend2</artifactId>
|
||||
<version>0.0.1-SNAPSHOT</version>
|
||||
<name>backend2</name>
|
||||
<description>backend2</description>
|
||||
|
||||
<properties>
|
||||
<java.version>1.8</java.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-jdbc</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>mysql</groupId>
|
||||
<artifactId>mysql-connector-java</artifactId>
|
||||
<scope>runtime</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.baomidou</groupId>
|
||||
<artifactId>mybatis-plus-boot-starter</artifactId>
|
||||
<version>2.2.0</version>
|
||||
</dependency>
|
||||
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-lang3 -->
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.9</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<optional>true</optional>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-test</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
</dependency>
|
||||
|
||||
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
<version>3.8.1</version>
|
||||
<configuration>
|
||||
<source>1.8</source>
|
||||
<target>1.8</target>
|
||||
<encoding>UTF-8</encoding>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-maven-plugin</artifactId>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
|
||||
</project>
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.common.backend2;
|
||||
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
|
||||
@SpringBootApplication
|
||||
public class Backend2Application {
|
||||
|
||||
public static void main(String[] args) {
|
||||
SpringApplication.run(Backend2Application.class, args);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package com.common.backend2.config;
|
||||
|
||||
import org.springframework.web.servlet.HandlerInterceptor;
|
||||
import org.springframework.web.servlet.ModelAndView;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.List;
|
||||
|
||||
public class CommonInterceptor implements HandlerInterceptor {
|
||||
private List<String> excludedUrls;
|
||||
|
||||
public List<String> getExcludedUrls() {
|
||||
return excludedUrls;
|
||||
}
|
||||
|
||||
public void setExcludedUrls(List<String> excludedUrls) {
|
||||
this.excludedUrls = excludedUrls;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* 在业务处理器处理请求之前被调用 如果返回false
|
||||
* 从当前的拦截器往回执行所有拦截器的afterCompletion(),
|
||||
* 再退出拦截器链, 如果返回true 执行下一个拦截器,
|
||||
* 直到所有的拦截器都执行完毕 再执行被拦截的Controller
|
||||
* 然后进入拦截器链,
|
||||
* 从最后一个拦截器往回执行所有的postHandle()
|
||||
* 接着再从最后一个拦截器往回执行所有的afterCompletion()
|
||||
* @param request
|
||||
* @param response
|
||||
*/
|
||||
public boolean preHandle(HttpServletRequest request, HttpServletResponse response,
|
||||
Object handler) throws Exception {
|
||||
response.setHeader("Access-Control-Allow-Origin", "*");
|
||||
response.setHeader("Access-Control-Allow-Methods", "*");
|
||||
response.setHeader("Access-Control-Max-Age", "3600");
|
||||
response.setHeader("Access-Control-Allow-Headers",
|
||||
"Origin, X-Requested-With, Content-Type, Accept");
|
||||
return true;
|
||||
}
|
||||
|
||||
// 在业务处理器处理请求执行完成后,生成视图之前执行的动作
|
||||
public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler,
|
||||
ModelAndView modelAndView) throws Exception {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 在DispatcherServlet完全处理完请求后被调用
|
||||
* 当有拦截器抛出异常时,
|
||||
* 会从当前拦截器往回执行所有的拦截器的afterCompletion()
|
||||
* @param request
|
||||
* @param response
|
||||
* @param handler
|
||||
*
|
||||
*/
|
||||
public void afterCompletion(HttpServletRequest request, HttpServletResponse response,
|
||||
Object handler, Exception ex) throws Exception {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.common.backend2.config;
|
||||
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.http.client.ClientHttpRequestFactory;
|
||||
import org.springframework.http.client.SimpleClientHttpRequestFactory;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
@Configuration
|
||||
public class ResTemplateConfig {
|
||||
@Bean
|
||||
public RestTemplate restTemplate(ClientHttpRequestFactory factory) {
|
||||
return new RestTemplate(factory);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public ClientHttpRequestFactory simpleClientHttpRequestFactory() {
|
||||
SimpleClientHttpRequestFactory factory = new SimpleClientHttpRequestFactory();
|
||||
//超时设置
|
||||
factory.setReadTimeout(10000);//ms
|
||||
factory.setConnectTimeout(15000);//ms
|
||||
return factory;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
package com.common.backend2.config;
|
||||
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.web.servlet.config.annotation.CorsRegistry;
|
||||
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
|
||||
import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry;
|
||||
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
@Configuration
|
||||
public class WebMvcConfig implements WebMvcConfigurer {
|
||||
@Override
|
||||
public void addCorsMappings(CorsRegistry registry) {
|
||||
registry.addMapping("/**")
|
||||
.allowedHeaders("*")
|
||||
.allowedOriginPatterns("*")
|
||||
.allowedMethods("POST", "GET", "PUT", "OPTIONS", "DELETE")
|
||||
.maxAge(3600)
|
||||
.allowCredentials(true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addInterceptors(InterceptorRegistry registry) {
|
||||
WebMvcConfigurer.super.addInterceptors(registry);
|
||||
registry.addInterceptor(CommonInterceptor()).addPathPatterns("/**");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addResourceHandlers(ResourceHandlerRegistry registry) {
|
||||
/**
|
||||
* 资源映射路径
|
||||
* addResourceHandler:访问映射路径
|
||||
* addResourceLocations:资源绝对路径
|
||||
*/
|
||||
//String projectRootDirectoryPath = System.getProperty("user.dir");
|
||||
// 通过 File 对象的 getParent() 方法获取到根目录的上级目录
|
||||
//String parentPath = new File(projectRootDirectoryPath).getParent();
|
||||
|
||||
// registry.addResourceHandler("/result/**")
|
||||
// .addResourceLocations("file:///"+parentPath+"/RED-CNN-master(Lite)\\save/fig/");
|
||||
registry.addResourceHandler("/result/**")
|
||||
.addResourceLocations("file:///E:/Document/project/temp/save/");
|
||||
registry.addResourceHandler("/assets/**")
|
||||
.addResourceLocations("file:///E:/Document/project/temp/assets/");
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Bean
|
||||
public CommonInterceptor CommonInterceptor() {
|
||||
return new CommonInterceptor();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package com.common.backend2.controller;
|
||||
|
||||
|
||||
import com.common.backend2.entity.File;
|
||||
import com.common.backend2.response.Result;
|
||||
import com.common.backend2.response.ResponseCode;
|
||||
import com.common.backend2.service.FileService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.LinkedList;
|
||||
import java.util.Queue;
|
||||
|
||||
@CrossOrigin
|
||||
@RestController
|
||||
@RequestMapping("/api")
|
||||
public class FileController {
|
||||
|
||||
|
||||
@Autowired
|
||||
private FileService fileService;
|
||||
|
||||
|
||||
Queue<String> queue =new LinkedList(); //test_patient
|
||||
|
||||
@RequestMapping(value="/multi/uploadMultiImage",method=RequestMethod.POST)
|
||||
public Result uploadMultiImage(@RequestParam("files") MultipartFile[] files,HttpServletRequest request){
|
||||
//files 就是前端传来的多文件数组
|
||||
if (files.length<=0) {
|
||||
return new Result(ResponseCode.FILE_EMPTY.getCode(), ResponseCode.FILE_EMPTY.getMsg(), null);
|
||||
}
|
||||
for(MultipartFile file :files){
|
||||
fileService.upLoadFiles(file,request);
|
||||
}
|
||||
return new Result(ResponseCode.SUCCESS.getCode(), ResponseCode.SUCCESS.getMsg(), "数据上传成功");
|
||||
}
|
||||
|
||||
@RequestMapping(value = "/upload",method = RequestMethod.POST)
|
||||
public Result upLoadFile(@RequestParam("file") MultipartFile multipartFile,HttpServletRequest request) {
|
||||
if (multipartFile.isEmpty()) {
|
||||
return new Result(ResponseCode.FILE_EMPTY.getCode(), ResponseCode.FILE_EMPTY.getMsg(), null);
|
||||
}
|
||||
return fileService.upLoadFiles(multipartFile,request);
|
||||
|
||||
}
|
||||
@RequestMapping(value="/predict/{fileId}",method=RequestMethod.GET)
|
||||
public Result predictAndreturnResult (@PathVariable("fileId")Integer id){
|
||||
|
||||
return fileService.processFile(id);
|
||||
}
|
||||
@RequestMapping(value = "/download/{id}",method = RequestMethod.GET)
|
||||
public void downloadFiles(@PathVariable("id") String id, HttpServletRequest request, HttpServletResponse response){
|
||||
OutputStream outputStream=null;
|
||||
InputStream inputStream=null;
|
||||
BufferedInputStream bufferedInputStream=null;
|
||||
byte[] bytes=new byte[1024];
|
||||
File file = fileService.getFileById(id);
|
||||
String fileName = file.getFileName();
|
||||
// 获取输出流
|
||||
try {
|
||||
response.setHeader("Content-Disposition", "attachment;filename=" + new String(fileName.getBytes(StandardCharsets.UTF_8), StandardCharsets.ISO_8859_1));
|
||||
response.setContentType("application/force-download");
|
||||
inputStream=fileService.getFileInputStream(file);
|
||||
bufferedInputStream=new BufferedInputStream(inputStream);
|
||||
outputStream = response.getOutputStream();
|
||||
int i=bufferedInputStream.read(bytes);
|
||||
while (i!=-1){
|
||||
outputStream.write(bytes,0,i);
|
||||
i=bufferedInputStream.read(bytes);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}finally {
|
||||
try {
|
||||
if (inputStream!=null){
|
||||
inputStream.close();
|
||||
}
|
||||
if (outputStream!=null){
|
||||
outputStream.close();
|
||||
}
|
||||
if (bufferedInputStream!=null){
|
||||
bufferedInputStream.close();
|
||||
}
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@RequestMapping(value="/getResult",method=RequestMethod.GET)
|
||||
public Result getUrlsByIp(HttpServletRequest request,HttpServletResponse response){
|
||||
return fileService.returnUrls(request,response);
|
||||
}
|
||||
@RequestMapping(value="/delete/{id}",method=RequestMethod.DELETE)
|
||||
public Result deleteOriginAndProcess(@PathVariable("id")Integer id ){
|
||||
return fileService.deleteOne(id);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package com.common.backend2.entity;
|
||||
|
||||
import com.baomidou.mybatisplus.annotations.TableId;
|
||||
import com.baomidou.mybatisplus.annotations.TableName;
|
||||
import com.baomidou.mybatisplus.enums.IdType;
|
||||
import lombok.*;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@Getter
|
||||
@Setter
|
||||
@EqualsAndHashCode
|
||||
@TableName(value = "file")
|
||||
public class File implements Serializable {
|
||||
|
||||
private static final long serialVersionUID=1L;
|
||||
|
||||
/**
|
||||
* 文件id
|
||||
*/
|
||||
@TableId(value="id",type = IdType.AUTO)
|
||||
private Integer fileId;
|
||||
/**
|
||||
* 文件存储路径
|
||||
*/
|
||||
private String filePath;
|
||||
/**
|
||||
* 文件存储路径
|
||||
*/
|
||||
private String savePath;
|
||||
/**
|
||||
* 文件名称
|
||||
*/
|
||||
private String fileName;
|
||||
/**
|
||||
*
|
||||
* 是否处理过
|
||||
*
|
||||
*/
|
||||
private Integer processed;
|
||||
|
||||
private String result;
|
||||
|
||||
|
||||
public File(Object o, String saveChildPath, String resultChildPath, String fileName) {
|
||||
this.fileId=null;
|
||||
this.filePath= saveChildPath;
|
||||
this.savePath = resultChildPath;
|
||||
this.fileName = fileName;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.common.backend2.mapper;
|
||||
|
||||
|
||||
import com.baomidou.mybatisplus.mapper.BaseMapper;
|
||||
|
||||
import com.common.backend2.entity.File;
|
||||
import org.apache.ibatis.annotations.Insert;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
import org.apache.ibatis.annotations.Select;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Mapper
|
||||
@Repository
|
||||
public interface FileMapper extends BaseMapper<File> {
|
||||
@Select("select * from file")
|
||||
List<File>findAll();
|
||||
@Select("SELECT LAST_INSERT_ID();")
|
||||
Integer getId();
|
||||
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package com.common.backend2.response;
|
||||
|
||||
|
||||
public enum ResponseCode {
|
||||
// 系统模块
|
||||
SUCCESS(0, "操作成功"),
|
||||
ERROR(1, "操作失败"),
|
||||
SERVER_ERROR(500, "服务器异常"),
|
||||
|
||||
// 通用模块 1xxxx
|
||||
ILLEGAL_ARGUMENT(10000, "参数不合法"),
|
||||
REPETITIVE_OPERATION(10001, "请勿重复操作"),
|
||||
ACCESS_LIMIT(10002, "请求太频繁, 请稍后再试"),
|
||||
MAIL_SEND_SUCCESS(10003, "邮件发送成功"),
|
||||
|
||||
// 用户模块 2xxxx
|
||||
NEED_LOGIN(20001, "登录失效"),
|
||||
USERNAME_OR_PASSWORD_EMPTY(20002, "用户名或密码不能为空"),
|
||||
USERNAME_OR_PASSWORD_WRONG(20003, "用户名或密码错误"),
|
||||
USER_NOT_EXISTS(20004, "用户不存在"),
|
||||
WRONG_PASSWORD(20005, "密码错误"),
|
||||
|
||||
// 文件模块 3xxxx
|
||||
FILE_EMPTY(30001,"文件不能空"),
|
||||
FILE_NAME_EMPTY(30002,"文件名称不能为空"),
|
||||
FILE_MAX_SIZE(30003,"文件大小超出"),
|
||||
;
|
||||
|
||||
ResponseCode(Integer code, String msg) {
|
||||
this.code = code;
|
||||
this.msg = msg;
|
||||
}
|
||||
|
||||
private Integer code;
|
||||
|
||||
private String msg;
|
||||
|
||||
public Integer getCode() {
|
||||
return code;
|
||||
}
|
||||
|
||||
public void setCode(Integer code) {
|
||||
this.code = code;
|
||||
}
|
||||
|
||||
public String getMsg() {
|
||||
return msg;
|
||||
}
|
||||
|
||||
public void setMsg(String msg) {
|
||||
this.msg = msg;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.common.backend2.response;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class Result {
|
||||
private int code;
|
||||
private String message;
|
||||
private Object data;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package com.common.backend2.service;
|
||||
|
||||
|
||||
|
||||
import com.baomidou.mybatisplus.service.IService;
|
||||
import com.common.backend2.entity.File;
|
||||
import com.common.backend2.response.Result;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.InputStream;
|
||||
|
||||
|
||||
public interface FileService extends IService<File> {
|
||||
|
||||
/**
|
||||
* 文件上传接口
|
||||
* @param file
|
||||
* @return
|
||||
*/
|
||||
Result upLoadFiles(MultipartFile file, HttpServletRequest request);
|
||||
|
||||
/**
|
||||
* 根据id获取文件
|
||||
* @param id
|
||||
* @return
|
||||
*/
|
||||
File getFileById(String id);
|
||||
|
||||
/**
|
||||
* 根据id获取数据流
|
||||
* @param file
|
||||
* @return
|
||||
*/
|
||||
InputStream getFileInputStream(File file);
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* 对文件进行处理
|
||||
* @param
|
||||
* @return: 返回处理完的文件的id
|
||||
*/
|
||||
Result processFile(Integer id);
|
||||
|
||||
|
||||
Result returnUrls(HttpServletRequest request, HttpServletResponse response);
|
||||
|
||||
Result deleteOne(Integer id);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package com.common.backend2.service.Impl;
|
||||
|
||||
|
||||
import com.baomidou.mybatisplus.service.impl.ServiceImpl;
|
||||
|
||||
import com.common.backend2.entity.File;
|
||||
import com.common.backend2.mapper.FileMapper;
|
||||
import com.common.backend2.response.ResponseCode;
|
||||
import com.common.backend2.response.Result;
|
||||
|
||||
import com.common.backend2.service.FileService;
|
||||
import com.common.backend2.utils.IpUtil;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.*;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Service
|
||||
public class FileServiceImpl extends ServiceImpl<FileMapper, File> implements FileService {
|
||||
|
||||
|
||||
@Value("${file.data-path}")
|
||||
private String dataPath;
|
||||
// @Value("${file.parent-path}")
|
||||
// private String parentPath;
|
||||
@Value("${file.result-path}")
|
||||
private String resultPath;
|
||||
@Value("${url.get-result-url}")
|
||||
private String getResultUrl;
|
||||
@Autowired
|
||||
private FileMapper fileMapper;
|
||||
|
||||
@Autowired
|
||||
private RestTemplateMethods restTemplateMethods;
|
||||
|
||||
@Transactional
|
||||
public Result upLoadFiles(MultipartFile file, HttpServletRequest request) {
|
||||
long MAX_SIZE = 2097152L;
|
||||
|
||||
String fileName = file.getOriginalFilename(); //获取原名
|
||||
if (StringUtils.isEmpty(fileName)) {
|
||||
return new Result(ResponseCode.FILE_NAME_EMPTY.getCode(), ResponseCode.FILE_NAME_EMPTY.getMsg(), null);
|
||||
} // 判断不为空
|
||||
if (file.getSize() > MAX_SIZE) {
|
||||
return new Result(ResponseCode.FILE_MAX_SIZE.getCode(), ResponseCode.FILE_MAX_SIZE.getMsg(), null);
|
||||
} // 判断大小是否超出限制
|
||||
// String suffixName = fileName.contains(".") ? fileName.substring(fileName.lastIndexOf(".")) : null; //获取没有后缀的文件名
|
||||
// 获取当前时间戳作为文件名
|
||||
Date date = new Date();
|
||||
SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd-HH:mm:ss");
|
||||
String dateString = formatter.format(date);
|
||||
|
||||
String saveChildPath = dataPath ;//+ "/" + dateString; // 使用时间戳字符串存储
|
||||
|
||||
String resultChildPath = resultPath ;//+ "/" + dateString;
|
||||
java.io.File newFile = new java.io.File(saveChildPath, fileName); //新建一个文件进行存储
|
||||
if (!newFile.getParentFile().exists()) {
|
||||
newFile.getParentFile().mkdirs();
|
||||
} //文件创建
|
||||
try {
|
||||
//文件写入--文件的转存
|
||||
file.transferTo(newFile);
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
} // 转存到文件中
|
||||
File files = new File(null, saveChildPath, resultChildPath, fileName); // 信息包装成对象准备存入数据库
|
||||
fileMapper.insert(files); //插入到数据库
|
||||
int fildId = fileMapper.getId();
|
||||
|
||||
return new Result(ResponseCode.SUCCESS.getCode(), ResponseCode.SUCCESS.getMsg(), fildId);
|
||||
}
|
||||
|
||||
public File getFileById(String id) {
|
||||
return fileMapper.selectById(id);
|
||||
|
||||
}
|
||||
|
||||
public InputStream getFileInputStream(File files) {
|
||||
java.io.File file = new java.io.File(files.getFilePath());
|
||||
try {
|
||||
return new FileInputStream(file);
|
||||
} catch (FileNotFoundException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
public Result processFile(Integer id) {
|
||||
//
|
||||
// File file = fileMapper.selectById(id);
|
||||
// String fileName = file.getFileName();
|
||||
// String classAndScore=restTemplateMethods.RestTemplatePost(fileName);
|
||||
// String resultUrl=getResultUrl+"/result/"+fileName;
|
||||
// ArrayList result =new ArrayList<>();
|
||||
// result.add(resultUrl);
|
||||
// result.add(classAndScore);
|
||||
// //从结果文件夹发回--这个可以用一个get来请求或者使用轮询
|
||||
// return new Result(ResponseCode.SUCCESS.getCode(), ResponseCode.SUCCESS.getMsg(), result);
|
||||
|
||||
File file = fileMapper.selectById(id);
|
||||
Integer processed = file.getProcessed();
|
||||
String fileName = file.getFileName();
|
||||
ArrayList result =new ArrayList<>();
|
||||
String resultUrl=getResultUrl+"/result/"+fileName;
|
||||
result.add(resultUrl);
|
||||
if(processed!=0){
|
||||
result.add(file.getResult());
|
||||
}
|
||||
else{
|
||||
String classAndScore=restTemplateMethods.RestTemplatePost(fileName);
|
||||
//将结果字符串保存至数据库
|
||||
file.setResult(classAndScore);
|
||||
file.setProcessed(1);
|
||||
fileMapper.updateById(file);
|
||||
|
||||
result.add(classAndScore);
|
||||
|
||||
//从结果文件夹发回--这个可以用一个get来请求或者使用轮询
|
||||
|
||||
}
|
||||
return new Result(ResponseCode.SUCCESS.getCode(), ResponseCode.SUCCESS.getMsg(), result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Result returnUrls(HttpServletRequest request, HttpServletResponse response) {
|
||||
|
||||
String ip = IpUtil.getClientIpAddr(request);
|
||||
String path = resultPath + "/" + ip; //python 图片目录
|
||||
String RetPath = getResultUrl + "/result/" + ip + "/";
|
||||
java.io.File file = new java.io.File(path);
|
||||
java.io.File[] files = file.listFiles(); //获取所有的文件名
|
||||
List<String> ResultUrls = new ArrayList<>();
|
||||
for (java.io.File file1 : files) {
|
||||
ResultUrls.add(RetPath + file1.getName());
|
||||
}
|
||||
return new Result(ResponseCode.SUCCESS.getCode(), ResponseCode.SUCCESS.getMsg(), ResultUrls);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Result deleteOne(Integer id) {
|
||||
File file =fileMapper.selectById(id);
|
||||
// 获取路径
|
||||
String originPath = file.getFilePath();
|
||||
String resultPath = file.getSavePath();
|
||||
String fileName = file.getFileName();
|
||||
String originFullPath = originPath+"/"+fileName;
|
||||
String resultFullPath = resultPath+"/"+fileName;
|
||||
|
||||
//删除文件
|
||||
java.io.File originFile = new java.io.File(originFullPath);
|
||||
java.io.File resultFile = new java.io.File(resultFullPath);
|
||||
|
||||
// 检查文件是否存在
|
||||
if (originFile.exists()) {
|
||||
// 尝试删除文件
|
||||
boolean deleted = originFile.delete();
|
||||
|
||||
if (deleted) {
|
||||
System.out.println("Origin文件已成功删除");
|
||||
} else {
|
||||
System.out.println("无法删除Origin文件");
|
||||
return new Result(ResponseCode.ERROR.getCode(), ResponseCode.ERROR.getMsg(), "Origin删除失败");
|
||||
}
|
||||
} else {
|
||||
System.out.println("Origin文件不存在");
|
||||
return new Result(ResponseCode.ERROR.getCode(), ResponseCode.ERROR.getMsg(), "Origin文件不存在");
|
||||
}
|
||||
|
||||
if (resultFile.exists()) {
|
||||
// 尝试删除文件
|
||||
boolean deleted = resultFile.delete();
|
||||
|
||||
if (deleted) {
|
||||
System.out.println("Result文件已成功删除");
|
||||
} else {
|
||||
System.out.println("无法删除Result文件");
|
||||
return new Result(ResponseCode.ERROR.getCode(), ResponseCode.ERROR.getMsg(), "Result删除失败");
|
||||
}
|
||||
} else {
|
||||
System.out.println("Result文件不存在");
|
||||
}
|
||||
//删除数据库数据
|
||||
int result = fileMapper.deleteById(file);
|
||||
|
||||
if (result > 0) {
|
||||
// 删除成功
|
||||
return new Result(ResponseCode.SUCCESS.getCode(), ResponseCode.SUCCESS.getMsg(), "删除成功");
|
||||
} else {
|
||||
// 删除失败
|
||||
return new Result(ResponseCode.ERROR.getCode(), ResponseCode.ERROR.getMsg(), "数据库删除失败");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package com.common.backend2.service.Impl;
|
||||
|
||||
|
||||
import jdk.nashorn.internal.parser.JSONParser;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.HttpURLConnection;
|
||||
import java.net.MalformedURLException;
|
||||
import java.net.URL;
|
||||
|
||||
@Service
|
||||
public class RestTemplateMethods {
|
||||
@Autowired
|
||||
private RestTemplate restTemplate;
|
||||
@Value("${url.predict-url}")
|
||||
String predictUrl ;
|
||||
|
||||
public String RestTemplatePost(String fileName) {
|
||||
|
||||
|
||||
|
||||
LinkedMultiValueMap<String, String> request = new LinkedMultiValueMap<>();
|
||||
//request.set("className",dateString);
|
||||
request.set("fileName", fileName);
|
||||
try {
|
||||
URL url1 = new URL(predictUrl);
|
||||
HttpURLConnection con =(HttpURLConnection)url1.openConnection();
|
||||
con.setRequestMethod("POST");
|
||||
con.setDoOutput(true);
|
||||
con.setDoInput(true);
|
||||
con.setUseCaches(false);
|
||||
con.setRequestProperty("Content-Type","application/from-data");
|
||||
} catch (MalformedURLException e) {
|
||||
e.printStackTrace();
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
//请求
|
||||
String result = restTemplate.postForObject(predictUrl, request, String.class);
|
||||
|
||||
System.out.println(result);
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package com.common.backend2.utils;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import java.net.InetAddress;
|
||||
import java.net.UnknownHostException;
|
||||
|
||||
|
||||
public class IpUtil {
|
||||
/**
|
||||
* @description: 获取请求的ip
|
||||
* @param request
|
||||
* @return: java.lang.String
|
||||
*/
|
||||
public static String getClientIpAddr(HttpServletRequest request) {
|
||||
String ip = request.getHeader("X-Forwarded-For");
|
||||
if (ip != null && ip.length() != 0 && !"unknown".equalsIgnoreCase(ip)) {
|
||||
// 多次反向代理后会有多个ip值,第一个ip才是真实ip
|
||||
int index = ip.indexOf(',');
|
||||
if (index != -1) {
|
||||
ip = ip.substring(0, index);
|
||||
}
|
||||
}
|
||||
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
|
||||
ip = request.getHeader("Proxy-Client-IP");
|
||||
}
|
||||
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
|
||||
ip = request.getHeader("WL-Proxy-Client-IP");
|
||||
}
|
||||
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
|
||||
ip = request.getHeader("HTTP_X_FORWARDED_FOR");
|
||||
}
|
||||
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
|
||||
ip = request.getHeader("HTTP_X_FORWARDED");
|
||||
}
|
||||
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
|
||||
ip = request.getHeader("HTTP_X_CLUSTER_CLIENT_IP");
|
||||
}
|
||||
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
|
||||
ip = request.getHeader("HTTP_CLIENT_IP");
|
||||
}
|
||||
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
|
||||
ip = request.getHeader("HTTP_FORWARDED_FOR");
|
||||
}
|
||||
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
|
||||
ip = request.getHeader("HTTP_FORWARDED");
|
||||
}
|
||||
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
|
||||
ip = request.getHeader("HTTP_VIA");
|
||||
}
|
||||
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
|
||||
ip = request.getHeader("REMOTE_ADDR");
|
||||
}
|
||||
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
|
||||
ip = request.getHeader("X-Real-IP");
|
||||
}
|
||||
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
|
||||
ip = request.getRemoteAddr();
|
||||
if (ip.equals("127.0.0.1") || ip.equals("0:0:0:0:0:0:0:1")) {
|
||||
//根据网卡取本机配置的IP
|
||||
InetAddress inet = null;
|
||||
try {
|
||||
inet = InetAddress.getLocalHost();
|
||||
} catch (UnknownHostException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
ip = inet.getHostAddress();
|
||||
}
|
||||
}
|
||||
return ip;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
server:
|
||||
port: 5000
|
||||
#port: 80
|
||||
servlet:
|
||||
encoding:
|
||||
charset: utf-8
|
||||
enabled: true
|
||||
force: true
|
||||
spring:
|
||||
datasource:
|
||||
driver-class-name: com.mysql.cj.jdbc.Driver
|
||||
name: defaultDataSource
|
||||
# url: jdbc:mysql://175.178.186.100:31441/DiatomRecognition?useUnicode=true&characterEncoding=UTF-8&serverTimezone=UTC
|
||||
url: jdbc:mysql://192.168.161.130:3306/DiatomRecognition?useUnicode=true&characterEncoding=UTF-8&serverTimezone=UTC
|
||||
username: root
|
||||
# password: xiyi409@
|
||||
password: 123456
|
||||
|
||||
mybatis-plus:
|
||||
typeAliasesPackage: com.common.backend2.entity
|
||||
mapperLocations: classpath:mapper/*.xml
|
||||
file: # 服务器
|
||||
data-path: /home/hejinwen/detect/data/upload
|
||||
result-path: /home/hejinwen/detect/data/save
|
||||
url: # 服务器
|
||||
predict-url: http://127.0.0.1:2000/predict
|
||||
get-result-url: http://192.168.161.130:80 # 走nginx的代理
|
||||
#url: # 服务器测试
|
||||
# predict-url: http://175.178.186.100:5000/predict
|
||||
# get-result-url: wadouri:http://175.178.186.100:6000
|
||||
#file: # 本地测试
|
||||
# parent-path: E:\Document\project\temp
|
||||
# data-path: E:/Document/project/temp/data
|
||||
# result-path: E:\Document\project\temp/save
|
||||
#url: # 本地测试
|
||||
# predict-url: http://127.0.0.1:2000/predict
|
||||
# get-result-url: http://127.0.0.1:5000
|
||||
logging:
|
||||
level:
|
||||
root: INFO
|
||||
web: DEBUG
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.common.backend2;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
|
||||
@SpringBootTest
|
||||
class Backend2ApplicationTests {
|
||||
|
||||
@Test
|
||||
void contextLoads() {
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
module.exports = {
|
||||
presets: [
|
||||
'@vue/cli-plugin-babel/preset'
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "es5",
|
||||
"module": "esnext",
|
||||
"baseUrl": "./",
|
||||
"moduleResolution": "node",
|
||||
"paths": {
|
||||
"@/*": [
|
||||
"src/*"
|
||||
]
|
||||
},
|
||||
"lib": [
|
||||
"esnext",
|
||||
"dom",
|
||||
"dom.iterable",
|
||||
"scripthost"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
export default [
|
||||
// 获取天气情况
|
||||
{
|
||||
method: 'get',
|
||||
url: '/getWeather',
|
||||
data: {
|
||||
code: 200,
|
||||
message: 'success',
|
||||
data: {
|
||||
weather: '小雨转多云',
|
||||
temperature: '13℃~18℃'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
const Mock = require('mockjs')
|
||||
|
||||
// 遍历所有mock文件
|
||||
const files = require.context('.', true, /\.js$/)
|
||||
let mockList = files.keys()
|
||||
.filter(key =>
|
||||
key !== './index.js' && files(key).default
|
||||
)
|
||||
.map(key =>
|
||||
files(key).default
|
||||
);
|
||||
|
||||
// 开始注册所有mock服务
|
||||
for (let list of mockList) { //遍历所有模块
|
||||
// 遍历模块中的所有api
|
||||
for (let item of list) {
|
||||
// 注入mock
|
||||
Mock.mock(
|
||||
'/api' + item.url,
|
||||
item.data
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
{
|
||||
"name": "bigscreen",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"serve": "vue-cli-service serve",
|
||||
"build": "vue-cli-service build",
|
||||
"lint": "vue-cli-service lint"
|
||||
},
|
||||
"dependencies": {
|
||||
"axios": "^1.4.0",
|
||||
"core-js": "^3.8.3",
|
||||
"dayjs": "^1.11.8",
|
||||
"echarts": "^5.4.2",
|
||||
"element-ui": "^2.15.13",
|
||||
"mockjs": "^1.1.0",
|
||||
"vue": "^2.6.14",
|
||||
"vue-baidu-map": "^0.21.22"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/core": "^7.12.16",
|
||||
"@babel/eslint-parser": "^7.12.16",
|
||||
"@vue/cli-plugin-babel": "~5.0.0",
|
||||
"@vue/cli-plugin-eslint": "~5.0.0",
|
||||
"@vue/cli-service": "~5.0.0",
|
||||
"eslint": "^7.32.0",
|
||||
"eslint-plugin-vue": "^8.0.3",
|
||||
"less": "^4.1.3",
|
||||
"less-loader": "^11.1.3",
|
||||
"vue-template-compiler": "^2.6.14"
|
||||
},
|
||||
"eslintConfig": {
|
||||
"root": true,
|
||||
"env": {
|
||||
"node": true
|
||||
},
|
||||
"extends": [
|
||||
"plugin:vue/essential",
|
||||
"eslint:recommended"
|
||||
],
|
||||
"parserOptions": {
|
||||
"parser": "@babel/eslint-parser"
|
||||
},
|
||||
"rules": {}
|
||||
},
|
||||
"browserslist": [
|
||||
"> 1%",
|
||||
"last 2 versions",
|
||||
"not dead"
|
||||
]
|
||||
}
|
||||
|
After Width: | Height: | Size: 4.2 KiB |
@@ -0,0 +1,38 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="">
|
||||
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1.0">
|
||||
<link rel="icon" href="<%= BASE_URL %>favicon.ico">
|
||||
<title>
|
||||
<%= htmlWebpackPlugin.options.title %>
|
||||
</title>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
html,
|
||||
body {
|
||||
height: 100%;
|
||||
}
|
||||
</style>
|
||||
<script>
|
||||
//fontsize计算 1rem=16px
|
||||
document.documentElement.style.fontSize = document.documentElement.clientWidth / 1920 * 16 + 'px'
|
||||
</script>
|
||||
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<noscript>
|
||||
<strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled.
|
||||
Please enable it to continue.</strong>
|
||||
</noscript>
|
||||
<div id="app"></div>
|
||||
<!-- built files will be auto injected -->
|
||||
</body>
|
||||
|
||||
</html>
|
||||
@@ -0,0 +1,589 @@
|
||||
<template>
|
||||
<div id="app">
|
||||
<div class="container">
|
||||
<div class="content-top">
|
||||
<span class="title">硅藻识别系统</span>
|
||||
<div class="time-box">
|
||||
<div id="time">{{ nowTime }}</div>
|
||||
<div class="line-box line-box1"></div>
|
||||
</div>
|
||||
<div class="weather-box">
|
||||
<div class="line-box line-box2"></div>
|
||||
<div id="weather">
|
||||
{{ weatherObj.weather }} {{
|
||||
weatherObj.temperature
|
||||
}}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="content-bottom">
|
||||
<div class="content-left">
|
||||
<section class="section-box roadDistribute-box">
|
||||
<!-- <div class="content-bodyhead">
|
||||
<div style="color: rgb(179, 179, 200); font-size: 18px;text-align: center;padding: 10px;">诊断信息输入</div>
|
||||
死者姓名:<input type="text" placeholder="在这里输入文本">
|
||||
<br>性别:<input type="text" placeholder="在这里输入文本">
|
||||
<br>法医鉴定:<input type="text" placeholder="在这里输入文本">
|
||||
<br>送检材料:<input type="text" placeholder="在这里输入文本">
|
||||
<br><label>
|
||||
<input type="checkbox" name="option1" value="Option 1"> 小环藻属
|
||||
</label>
|
||||
|
||||
<label>
|
||||
<input type="checkbox" name="option2" value="Option 2"> 直链藻属
|
||||
</label>
|
||||
|
||||
<label>
|
||||
<input type="checkbox" name="option3" value="Option 3"> 菱形藻属
|
||||
</label>
|
||||
</div> -->
|
||||
<div class="title center" >源文件上传</div>
|
||||
<div class="content-body">
|
||||
<div class="distribute-textBox">
|
||||
<div style="display: flex;justify-content: center; align-items: center; ">
|
||||
<!-- <img src="./assets/image/gdut.jpg" alt="" > -->
|
||||
<div class="imagebox" >
|
||||
<el-upload
|
||||
class="avatar-uploader"
|
||||
action= http://192.168.161.130:80/api/upload
|
||||
:show-file-list="false"
|
||||
:on-success="handleAvatarSuccess">
|
||||
|
||||
<img v-if="imageUrl" :src="imageUrl" class="avatar">
|
||||
|
||||
<i v-else class="el-icon-plus avatar-uploader-icon"></i>
|
||||
|
||||
<el-button size="small" class="btn" type="btn-primary">点击上传</el-button>
|
||||
</el-upload>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="table">
|
||||
<div class="title center" >检测结果</div>
|
||||
<thead>
|
||||
<tr>
|
||||
<td align="center" style="width:200px;">类别</td>
|
||||
<td align="center" style="width:200px;">概率</td>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="item in tableData" :key="item">
|
||||
<td align="center">{{item.class}}</td>
|
||||
<td align="center">{{item.score}}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
<div class="content-middle">
|
||||
<div class="image-box">
|
||||
<img :src="resultUrl" id="100"
|
||||
alt="">
|
||||
</div>
|
||||
<div class="btnArea">
|
||||
|
||||
<button type="btn-primary"
|
||||
@click="deleteImg"
|
||||
class="btn">删除</button>
|
||||
<button type="btn-primary"
|
||||
@click="getResult"
|
||||
class="btn">获取结果</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="content-right">
|
||||
<!-- <div class="section-box eventStatistics-box">
|
||||
<img src="./assets/image/logo.png"
|
||||
alt="">
|
||||
</div> -->
|
||||
<div class="title">我国主要水藻图例</div>
|
||||
<div class="content">
|
||||
<div class="row">
|
||||
<div class="column">
|
||||
<img src="./assets/image/a.png" alt="">
|
||||
<br>小环藻属
|
||||
</div>
|
||||
<div class="column">
|
||||
<img src="./assets/image/b.png" alt="">
|
||||
<br>直链藻属
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="column">
|
||||
<img src="./assets/image/c.png" alt="">
|
||||
<br>菱形藻属
|
||||
</div>
|
||||
<div class="column">
|
||||
<img src="./assets/image/d.png" alt="">
|
||||
<br>卵形藻属
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="column">
|
||||
<img src="./assets/image/e.png" alt="">
|
||||
<br>桥弯藻属
|
||||
</div>
|
||||
<div class="column">
|
||||
<img src="./assets/image/f.png" alt="">
|
||||
<br>异形藻属
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="column">
|
||||
<img src="./assets/image/g.png" alt="">
|
||||
<br>针杆藻属
|
||||
</div>
|
||||
<div class="column">
|
||||
<img src="./assets/image/h.png" alt="">
|
||||
<br>舟形藻属
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import dayjs from "dayjs";
|
||||
|
||||
|
||||
import {
|
||||
getWeather
|
||||
} from "@/api/bigScreen-api";
|
||||
import nodata from"./assets/image/no-data.png";
|
||||
import axios from 'axios'
|
||||
|
||||
export default {
|
||||
name: "App",
|
||||
components: {},
|
||||
data () {
|
||||
return {
|
||||
timer: null,
|
||||
// baseUrL:'http://192.168.161.130/',
|
||||
baseUrL:'http://192.168.161.130:80/',
|
||||
nowTime: dayjs(new Date()).format("YYYY-MM-DD HH:mm"),
|
||||
weatherObj: {},
|
||||
resultClass:"",
|
||||
resultScore:"",
|
||||
imageUrl: 'http://192.168.161.130:80/assets/gdut.jpg',
|
||||
processIdNow:1,
|
||||
resultUrl:nodata,
|
||||
// tableData: [{
|
||||
// class :'异极藻(Gomphonema)',
|
||||
// score :'0.8269811',
|
||||
// },
|
||||
// {
|
||||
// class :'异极藻(Gomphonema)',
|
||||
// score :'0.8269811',
|
||||
// },
|
||||
// {
|
||||
// class :'舟形藻(Navicula)',
|
||||
// score :'0.54520714',
|
||||
// },
|
||||
// ]
|
||||
tableData:[]
|
||||
};
|
||||
},
|
||||
mounted () {
|
||||
this.init();
|
||||
},
|
||||
methods: {
|
||||
init () {
|
||||
this.getNowTime(); // 获取当前时间
|
||||
this.getWeather(); // 获取天气
|
||||
},
|
||||
// 获取当前时间
|
||||
getNowTime () {
|
||||
clearInterval(this.timer);
|
||||
this.timer = setInterval(() => {
|
||||
this.nowTime = dayjs(new Date()).format("YYYY-MM-DD HH:mm");
|
||||
}, 1000 * 30);
|
||||
},
|
||||
// 获取天气
|
||||
getWeather () {
|
||||
getWeather().then((res) => {
|
||||
if (res.code == 200) {
|
||||
this.weatherObj = res.data;
|
||||
}
|
||||
});
|
||||
},
|
||||
handleAvatarSuccess(res, file) {
|
||||
this.imageUrl = URL.createObjectURL(file.raw);
|
||||
//获取图片id
|
||||
this.processIdNow = res.data;
|
||||
},
|
||||
|
||||
deleteImg () {
|
||||
|
||||
this.$confirm('此操作将永久删除该文件, 是否继续?', '提示', {
|
||||
confirmButtonText: '确定',
|
||||
cancelButtonText: '取消',
|
||||
type: 'warning'
|
||||
}).then(() => {
|
||||
axios.delete(this.baseUrL+`api/delete/${this.processIdNow}`).then((res)=>{
|
||||
if(res.data.code==0){
|
||||
this.$message.success('删除成功')
|
||||
}
|
||||
else{
|
||||
this.$message.error('删除失败')
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
this.$message.error('删除失败')
|
||||
});
|
||||
}).catch(() => {
|
||||
this.$message({
|
||||
type: 'info',
|
||||
message: '已取消删除'
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
},
|
||||
getResult(){
|
||||
|
||||
axios.get(this.baseUrL+`api/predict/${this.processIdNow}`).then((res)=>{
|
||||
if(res.data.code ==0){
|
||||
this.response_=res.data.data;
|
||||
console.log(res.data);
|
||||
console.log(this.response_[0])
|
||||
|
||||
document.getElementById('100').src = this.response_[0];
|
||||
var result_class_degree = JSON.parse(this.response_[1]);
|
||||
|
||||
// console.log(result_class_degree)
|
||||
// this.resultClass=result_class_degree[0].class;
|
||||
// this.resultScore=result_class_degree[0].score;
|
||||
this.tableData=[]
|
||||
for (var i = 0; i < result_class_degree.length; i++) {
|
||||
var rowData = result_class_degree[i];
|
||||
|
||||
|
||||
this.tableData.push({
|
||||
class: rowData.class,
|
||||
score: rowData.score
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
else{
|
||||
this.$message.error('失败')
|
||||
}
|
||||
}).catch(() => {
|
||||
this.$message.error('文件已删除或未上传文件')
|
||||
});
|
||||
},
|
||||
|
||||
|
||||
deactivated () {
|
||||
clearInterval(this.timer);
|
||||
},
|
||||
beforeDestroy () {
|
||||
clearInterval(this.timer);
|
||||
},
|
||||
}
|
||||
};
|
||||
</script>
|
||||
|
||||
<style lang="less" scoped>
|
||||
#app {
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
/* 布局相关 start */
|
||||
.container {
|
||||
height: 100%;
|
||||
background-image: url("~@/assets/image/background.png");
|
||||
background-size: cover;
|
||||
}
|
||||
|
||||
.content-title {
|
||||
height: 2rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
font-size: 1.2rem;
|
||||
font-weight: bolder;
|
||||
color: #0166e2;
|
||||
img {
|
||||
margin-right: 5px;
|
||||
}
|
||||
}
|
||||
|
||||
.content-body {
|
||||
height: calc(100% - 12.2rem);
|
||||
margin-top: 1rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: space-between;
|
||||
}
|
||||
.content-bodyhead{
|
||||
text-align:left;
|
||||
border:solid rgb(118, 103, 118);
|
||||
padding: 10px;
|
||||
color: #8c939d;
|
||||
}
|
||||
.content-bodyhead text{
|
||||
padding: 4px 0;
|
||||
}
|
||||
|
||||
|
||||
.section-box {
|
||||
box-sizing: border-box;
|
||||
padding: 1rem;
|
||||
border: 2px solid #ffffff;
|
||||
}
|
||||
.content-bottom {
|
||||
display: flex;
|
||||
flex-wrap: nowrap;
|
||||
position: relative;
|
||||
height: calc(100% - 7rem);
|
||||
margin-top: -1rem;
|
||||
padding: 0 1rem 1rem 1rem;
|
||||
}
|
||||
.content-bottom .title {
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
background-color: #2e325a;
|
||||
color: #c4c6c9;
|
||||
bottom: 10px;
|
||||
|
||||
}
|
||||
.center{
|
||||
display: flex;
|
||||
align-content: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.content-left {
|
||||
width: 27%;
|
||||
margin-right: 1rem;
|
||||
}
|
||||
|
||||
.content-middle {
|
||||
width: 46%;
|
||||
margin-right: 1rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: space-between;
|
||||
}
|
||||
|
||||
.content-right {
|
||||
width: 27%;
|
||||
text-align: center;
|
||||
// border-style:solid;
|
||||
border-color:#cfc5c5;
|
||||
background: rgba(255, 255, 255, 0.2); /* 使用半透明白色作为背景颜色 */
|
||||
backdrop-filter: blur(10px); /* 使用模糊滤镜来创建磨砂玻璃效果 */
|
||||
}
|
||||
|
||||
//右边盒子的布局
|
||||
.content-right .title {
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
background-color: #2e325a;
|
||||
color: #c4c6c9;
|
||||
bottom: 10px;
|
||||
}
|
||||
|
||||
.content {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.row {
|
||||
display: flex;
|
||||
flex-wrap: nowrap;
|
||||
margin: 10px 0 0 0;
|
||||
}
|
||||
|
||||
.column {
|
||||
flex: 1;
|
||||
border: 2px solid #ccc;
|
||||
padding: 10px;
|
||||
margin: 0 5px;
|
||||
overflow: hidden;
|
||||
color:#c4c6c9;
|
||||
background-color: #16388e;
|
||||
// max-height: 100%; /* 限制行的最大高度 */
|
||||
}
|
||||
.column img {
|
||||
max-width: 66%; /* 图片的最大宽度为列的宽度 */
|
||||
height: auto; /* 保持图片的纵横比 */
|
||||
}
|
||||
/* 布局相关 end */
|
||||
|
||||
/* header start */
|
||||
|
||||
.content-top {
|
||||
position: relative;
|
||||
height: 7rem;
|
||||
text-align: center;
|
||||
font-style: italic;
|
||||
font-weight: bolder;
|
||||
background: url("@/assets/image/background-top.jpg") no-repeat center;
|
||||
background-size: 100% 100%;
|
||||
background-position-y: -0.6875rem;
|
||||
color: rgb(194, 193, 194);
|
||||
.title {
|
||||
line-height: 5rem;
|
||||
font-size: 2rem;
|
||||
letter-spacing: 2px;
|
||||
}
|
||||
|
||||
.time-box {
|
||||
position: absolute;
|
||||
left: 5rem;
|
||||
top: 0.5rem;
|
||||
display: flex;
|
||||
height: 30%;
|
||||
font-size: 1.0625rem;
|
||||
}
|
||||
|
||||
#time,
|
||||
#weather {
|
||||
padding: 0.5rem 1.2rem 0 1.2rem;
|
||||
border-top: 2px solid #0166e2;
|
||||
letter-spacing: 1px;
|
||||
}
|
||||
|
||||
.weather-box {
|
||||
position: absolute;
|
||||
right: 5rem;
|
||||
top: 0.5rem;
|
||||
display: flex;
|
||||
height: 30%;
|
||||
font-size: 1.2rem;
|
||||
}
|
||||
|
||||
.line-box {
|
||||
width: 3rem;
|
||||
height: 30%;
|
||||
border-top: 2px solid #0166e2;
|
||||
}
|
||||
|
||||
.line-box1 {
|
||||
margin-left: 1rem;
|
||||
}
|
||||
.line-box2 {
|
||||
margin-right: 1rem;
|
||||
}
|
||||
}
|
||||
|
||||
/* header end */
|
||||
|
||||
.roadDistribute-box {
|
||||
height: 100%;
|
||||
|
||||
.distribute-textBox {
|
||||
height: 40%;
|
||||
padding: 1rem;
|
||||
line-height: 3.2rem;
|
||||
letter-spacing: 1px;
|
||||
font-size: 1.5rem;
|
||||
color: #ffffff;
|
||||
}
|
||||
}
|
||||
|
||||
.imagebox {
|
||||
display: flex;
|
||||
justify-content: center; //弹性盒子对象在主轴上的对齐方式
|
||||
align-items: center; //定义flex子项在flex容器的当前行的侧轴(纵轴)方向上的对齐方式。
|
||||
box-sizing: border-box;
|
||||
border-color: #ffffff;
|
||||
background: url("@/assets/image/background-map.png") no-repeat center;
|
||||
position:relative
|
||||
}
|
||||
.imagebox img {
|
||||
width: 300px;
|
||||
height: 300px;
|
||||
border: 1px solid #fff;
|
||||
background-color: #fff;
|
||||
|
||||
}
|
||||
.image-box {
|
||||
display: flex;
|
||||
justify-content: center; //弹性盒子对象在主轴上的对齐方式
|
||||
align-items: center; //定义flex子项在flex容器的当前行的侧轴(纵轴)方向上的对齐方式。
|
||||
box-sizing: border-box;
|
||||
background-size: 100% 100%;
|
||||
background: url("@/assets/image/background-map.png") no-repeat center;
|
||||
}
|
||||
.image-box img {
|
||||
width: 87%;
|
||||
height: 87%;
|
||||
background-color: #fff;
|
||||
}
|
||||
|
||||
.eventStatistics-box {
|
||||
height: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: space-between;
|
||||
}
|
||||
.eventStatistics-box img {
|
||||
width: 100%;
|
||||
height: 50%;
|
||||
background-color: #fff;
|
||||
}
|
||||
.btnArea {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
justify-content: flex-start;
|
||||
align-items: center;
|
||||
height: 10%;
|
||||
}
|
||||
.btn {
|
||||
background-image: linear-gradient(to bottom right, #196cd8, #3c85c0, #63a6e4);
|
||||
overflow: hidden;
|
||||
font-size: 20px;
|
||||
width: 60%;
|
||||
height: 80%;
|
||||
color: #ffffff;
|
||||
border: 0 solid #fff;
|
||||
opacity: 0.8;
|
||||
margin: 0 2px;
|
||||
border-radius: 20px;
|
||||
}
|
||||
.avatar-uploader .el-upload {
|
||||
border: 1px dashed #d9d9d9;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
}
|
||||
.avatar-uploader .el-upload:hover {
|
||||
border-color: #409EFF;
|
||||
}
|
||||
.avatar-uploader-icon {
|
||||
font-size: 28px;
|
||||
color: #8c939d;
|
||||
width: 178px;
|
||||
height: 178px;
|
||||
line-height: 178px;
|
||||
text-align: center;
|
||||
}
|
||||
.avatar {
|
||||
width: 178px;
|
||||
height: 178px;
|
||||
display: block;
|
||||
}
|
||||
.above_txt{
|
||||
position:absolute;top:108px;left:108px;
|
||||
}
|
||||
.table{
|
||||
margin-top :2rem;
|
||||
line-height: 2rem;
|
||||
letter-spacing: 0px;
|
||||
font-size:1.3rem;
|
||||
}
|
||||
.table .title{
|
||||
margin-top :2rem;
|
||||
letter-spacing: 0px;
|
||||
font-size:1.5rem;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,10 @@
|
||||
import request from '@/utils/request'
|
||||
|
||||
// 获取天气情况
|
||||
export function getWeather (param) {
|
||||
return request({
|
||||
url: '/getWeather',
|
||||
method: 'get',
|
||||
params: param
|
||||
})
|
||||
}
|
||||
|
After Width: | Height: | Size: 126 KiB |
|
After Width: | Height: | Size: 128 KiB |
|
After Width: | Height: | Size: 33 KiB |
|
After Width: | Height: | Size: 35 KiB |
|
After Width: | Height: | Size: 1.3 MiB |
|
After Width: | Height: | Size: 174 KiB |
|
After Width: | Height: | Size: 134 KiB |
|
After Width: | Height: | Size: 149 KiB |
|
After Width: | Height: | Size: 145 KiB |
|
After Width: | Height: | Size: 159 KiB |
|
After Width: | Height: | Size: 38 KiB |
|
After Width: | Height: | Size: 146 KiB |
|
After Width: | Height: | Size: 436 KiB |
|
After Width: | Height: | Size: 118 KiB |
|
After Width: | Height: | Size: 4.9 KiB |
@@ -0,0 +1,14 @@
|
||||
import Vue from 'vue'
|
||||
import App from './App.vue'
|
||||
|
||||
Vue.config.productionTip = false
|
||||
|
||||
import ElementUI from 'element-ui'
|
||||
import 'element-ui/lib/theme-chalk/index.css'
|
||||
Vue.use(ElementUI)
|
||||
|
||||
import '../mock/index'
|
||||
|
||||
new Vue({
|
||||
render: h => h(App),
|
||||
}).$mount('#app')
|
||||
@@ -0,0 +1,89 @@
|
||||
import axios from 'axios'
|
||||
import { Notification } from 'element-ui'
|
||||
// 创建axios实例
|
||||
const service = axios.create({
|
||||
baseURL: '/api',
|
||||
timeout: 80000, // 请求超时时间
|
||||
withCredentials: true,
|
||||
// crossDomain: true
|
||||
})
|
||||
|
||||
// request拦截器
|
||||
service.interceptors.request.use(
|
||||
config => {
|
||||
// if (getToken()) {
|
||||
// config.headers['Authorization'] = getToken() // 让每个请求携带自定义token 请根据实际情况自行修改
|
||||
// }
|
||||
var lang = localStorage.getItem('lang')//因为项目中使用到了i18n国际化语言配置,请根据实际情况自行修改
|
||||
if (!lang) {
|
||||
lang = 'zh_CN'
|
||||
}
|
||||
config.headers['Accept-Language'] = lang.replace(/_/g, '-')
|
||||
config.headers['Content-Type'] = 'application/json'
|
||||
return config
|
||||
},
|
||||
error => {
|
||||
Promise.reject(error)
|
||||
}
|
||||
)
|
||||
|
||||
// response 拦截器
|
||||
service.interceptors.response.use(
|
||||
response => {
|
||||
return response.data
|
||||
},
|
||||
error => {
|
||||
// 兼容blob下载出错json提示
|
||||
if (error.response.data instanceof Blob && error.response.data.type.toLowerCase().indexOf('json') !== -1) {
|
||||
const reader = new FileReader()
|
||||
reader.readAsText(error.response.data, 'utf-8')
|
||||
reader.onload = function () {
|
||||
const errorMsg = JSON.parse(reader.result).message
|
||||
Notification.error({
|
||||
title: errorMsg,
|
||||
duration: 5000
|
||||
})
|
||||
}
|
||||
} else {
|
||||
let code = 0
|
||||
try {
|
||||
code = error.response.data.status
|
||||
} catch (e) {
|
||||
if (error.toString().indexOf('Error: timeout') !== -1) {
|
||||
Notification.error({
|
||||
title: '网络请求超时',
|
||||
duration: 5000
|
||||
})
|
||||
return Promise.reject(error)
|
||||
}
|
||||
}
|
||||
|
||||
if (code) {
|
||||
// if (code === 401) {
|
||||
// store.dispatch('LogOut').then(() => {
|
||||
// // 用户登录界面提示
|
||||
// Cookies.set('point', 401)
|
||||
// location.reload()
|
||||
// })
|
||||
// } else if (code === 403) {
|
||||
// router.push({ path: '/401' })
|
||||
// } else {
|
||||
// const errorMsg = error.response.data.message
|
||||
// if (errorMsg !== undefined) {
|
||||
// Notification.error({
|
||||
// title: errorMsg,
|
||||
// duration: 0
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
} else {
|
||||
Notification.error({
|
||||
title: '接口请求失败',
|
||||
duration: 5000
|
||||
})
|
||||
}
|
||||
}
|
||||
return Promise.reject(error)
|
||||
}
|
||||
)
|
||||
export default service
|
||||
@@ -0,0 +1,16 @@
|
||||
const { defineConfig } = require('@vue/cli-service')
|
||||
module.exports = defineConfig({
|
||||
transpileDependencies: true,
|
||||
devServer: {
|
||||
open: true, // 运行后自动打开浏览器
|
||||
port: 8080,
|
||||
// proxy: {
|
||||
// '/api': {
|
||||
// target: 'http://192.168.1.102:8088',
|
||||
// ws: false,
|
||||
// changeOrigin: true,
|
||||
// logLevel: 'debug'
|
||||
// }
|
||||
// }
|
||||
},
|
||||
})
|
||||
@@ -0,0 +1,140 @@
|
||||
# ignore map, miou, datasets
|
||||
map_out/
|
||||
miou_out/
|
||||
VOCdevkit/
|
||||
datasets/
|
||||
Medical_Datasets/
|
||||
lfw/
|
||||
logs/
|
||||
model_data/
|
||||
.temp_map_out/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
@@ -0,0 +1,209 @@
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001406.jpg 243,20,271,419,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001872.jpg 72,116,437,352,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000618.jpg 70,168,421,254,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000583.jpg 200,96,310,396,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000784.jpg 141,104,380,360,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000210.jpg 119,106,383,370,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000915.jpg 222,50,294,422,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000004.jpg 174,142,327,295,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001580.jpg 130,103,408,361,4
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002486.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002287.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000883.jpg 111,150,402,285,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001285.jpg 196,70,326,416,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000640.jpg 125,27,378,412,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001855.jpg 201,109,294,343,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001989.jpg 128,165,406,293,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001300.jpg 154,6,326,410,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000798.jpg 90,99,410,281,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001990.jpg 157,90,329,305,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002152.jpg 186,123,328,339,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002407.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000291.jpg 189,160,317,290,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002381.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000657.jpg 211,143,290,311,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002015.jpg 127,116,419,244,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001517.jpg 110,114,411,298,4
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000383.jpg 136,126,342,333,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000154.jpg 141,112,367,334,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000480.jpg 194,68,316,406,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000846.jpg 230,58,306,391,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001419.jpg 207,52,305,389,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002539.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001318.jpg 70,23,382,437,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001805.jpg 74,154,416,292,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001021.jpg 53,145,443,228,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000484.jpg 109,149,414,240,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002356.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001985.jpg 194,105,359,329,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002352.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001148.jpg 89,89,413,354,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001066.jpg 152,133,381,319,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001489.jpg 206,94,307,353,4
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001852.jpg 60,144,427,257,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000762.jpg 72,195,449,273,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000709.jpg 154,111,358,322,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001017.jpg 209,51,286,422,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001785.jpg 110,185,404,310,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002260.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001718.jpg 1,122,475,248,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002128.jpg 130,117,392,282,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000898.jpg 47,97,483,371,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000503.jpg 115,54,350,382,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000961.jpg 33,162,481,253,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001691.jpg 153,57,352,348,4
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000337.jpg 183,159,311,284,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002061.jpg 116,109,396,312,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002143.jpg 117,63,398,320,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000735.jpg 155,56,334,371,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001213.jpg 150,45,354,392,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001251.jpg 229,80,284,386,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002163.jpg 110,137,417,330,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002448.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000086.jpg 138,123,352,332,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002581.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001765.jpg 186,148,332,302,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001759.jpg 109,22,342,374,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001772.jpg 166,15,340,388,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001910.jpg 188,128,286,382,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002132.jpg 100,98,453,371,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001648.jpg 99,162,361,257,4
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001762.jpg 223,114,318,340,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002157.jpg 204,124,314,332,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001701.jpg 136,71,360,420,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001666.jpg 135,91,370,318,4
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001850.jpg 53,158,452,307,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000122.jpg 191,167,305,278,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001741.jpg 172,65,355,414,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001519.jpg 60,131,433,324,4
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001252.jpg 71,32,452,415,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002371.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000936.jpg 129,59,384,410,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002173.jpg 130,135,403,301,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001579.jpg 25,54,404,421,4
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002307.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001681.jpg 115,108,379,277,4
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000547.jpg 135,169,343,272,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002215.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000861.jpg 104,62,370,424,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002398.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000893.jpg 217,41,295,420,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002064.jpg 143,113,365,351,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002522.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000856.jpg 107,155,383,234,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000228.jpg 176,136,356,315,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000633.jpg 109,117,416,316,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000022.jpg 155,105,341,290,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001088.jpg 80,145,431,305,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002565.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002276.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000308.jpg 90,68,385,359,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000828.jpg 218,51,309,354,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001180.jpg 87,138,374,342,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000264.jpg 118,128,317,329,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000782.jpg 191,5,318,396,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000157.jpg 103,79,402,380,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001228.jpg 181,62,333,388,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002034.jpg 144,86,329,412,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001423.jpg 147,89,382,378,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000060.jpg 117,118,393,393,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001720.jpg 51,159,434,298,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001024.jpg 26,19,460,422,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001766.jpg 133,31,370,399,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000359.jpg 109,66,369,329,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000943.jpg 114,193,395,248,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000216.jpg 141,106,393,363,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001768.jpg 172,1,330,396,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002488.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002186.jpg 154,87,351,397,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000469.jpg 67,152,426,307,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002406.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000506.jpg 120,160,417,297,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002077.jpg 114,107,380,324,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000093.jpg 137,129,358,347,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002239.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001714.jpg 187,51,314,393,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000849.jpg 119,112,394,309,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001811.jpg 190,79,381,364,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001173.jpg 110,99,401,364,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000047.jpg 159,115,365,318,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001722.jpg 150,116,378,300,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001605.jpg 160,35,352,355,4
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002266.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000971.jpg 73,180,419,218,2 76,204,418,246,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002177.jpg 120,63,391,374,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001166.jpg 191,37,329,444,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002304.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000962.jpg 107,29,397,381,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001005.jpg 43,122,466,347,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001556.jpg 180,127,328,304,4
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001776.jpg 171,147,316,347,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002370.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001903.jpg 206,60,284,402,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000651.jpg 129,137,437,323,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001212.jpg 70,77,445,339,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002274.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002102.jpg 88,97,407,333,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000705.jpg 135,130,382,325,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002455.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000776.jpg 76,85,413,310,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002261.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001132.jpg 32,98,440,280,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000582.jpg 180,103,293,351,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001026.jpg 125,177,387,237,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002011.jpg 126,78,391,437,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001115.jpg 212,44,293,403,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001839.jpg 70,167,438,282,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001122.jpg 95,30,405,442,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002063.jpg 143,88,357,374,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001283.jpg 43,117,420,359,3 212,132,334,443,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001725.jpg 210,139,287,349,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000973.jpg 152,7,373,426,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000237.jpg 115,108,380,374,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000942.jpg 55,178,443,274,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001745.jpg 87,118,425,310,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002360.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001790.jpg 144,112,405,366,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002312.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000857.jpg 137,205,358,264,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001594.jpg 211,66,300,399,4
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000663.jpg 176,123,346,336,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001964.jpg 109,165,392,302,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001395.jpg 76,66,408,378,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000843.jpg 71,111,418,400,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000371.jpg 133,119,370,352,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000795.jpg 225,102,311,386,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001211.jpg 107,21,420,407,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000835.jpg 118,108,343,320,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001366.jpg 85,81,433,376,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002242.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000677.jpg 86,112,421,341,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002289.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000381.jpg 135,117,360,334,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000722.jpg 154,180,382,267,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002336.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000350.jpg 141,72,415,343,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002311.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002354.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000485.jpg 221,137,291,331,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000130.jpg 118,73,369,321,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001465.jpg 16,90,459,390,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000964.jpg 208,30,318,402,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000457.jpg 220,75,300,328,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002018.jpg 138,47,360,438,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000768.jpg 104,88,400,381,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000479.jpg 187,82,340,396,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000811.jpg 8,133,479,332,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000834.jpg 31,39,489,388,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000007.jpg 138,112,373,343,0
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002146.jpg 189,123,323,341,6
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001070.jpg 113,123,406,333,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002579.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000787.jpg 166,45,325,429,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000594.jpg 122,63,393,434,1
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002562.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001371.jpg 158,93,323,397,3
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/000920.jpg 68,90,416,366,2
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002458.jpg
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/001826.jpg 94,149,386,384,5 26,37,76,152,5
|
||||
/media/SSD/luyuetong-data/yolox-pytorch/VOCdevkit/VOC2007/JPEGImages/002434.jpg
|
||||
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright Megvii, Base Detection
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -0,0 +1,180 @@
|
||||
## YOLOX:You Only Look Once目标检测模型在Pytorch当中的实现
|
||||
---
|
||||
|
||||
## 目录
|
||||
1. [仓库更新 Top News](#仓库更新)
|
||||
2. [相关仓库 Related code](#相关仓库)
|
||||
3. [性能情况 Performance](#性能情况)
|
||||
4. [实现的内容 Achievement](#实现的内容)
|
||||
5. [所需环境 Environment](#所需环境)
|
||||
6. [文件下载 Download](#文件下载)
|
||||
7. [训练步骤 How2train](#训练步骤)
|
||||
8. [预测步骤 How2predict](#预测步骤)
|
||||
9. [评估步骤 How2eval](#评估步骤)
|
||||
10. [参考资料 Reference](#Reference)
|
||||
|
||||
## Top News
|
||||
**`2022-04`**:**支持多GPU训练,新增各个种类目标数量计算,新增heatmap。**
|
||||
|
||||
**`2022-03`**:**进行了大幅度的更新,支持step、cos学习率下降法、支持adam、sgd优化器选择、支持学习率根据batch_size自适应调整、新增图片裁剪。**
|
||||
BiliBili视频中的原仓库地址为:https://github.com/bubbliiiing/yolox-pytorch/tree/bilibili
|
||||
|
||||
**`2021-10`**:**创建仓库,支持不同尺寸模型训练、支持大量可调整参数,支持fps、视频预测、批量预测等功能。**
|
||||
|
||||
## 相关仓库
|
||||
| 模型 | 路径 |
|
||||
| :----- | :----- |
|
||||
YoloV3 | https://github.com/bubbliiiing/yolo3-pytorch
|
||||
Efficientnet-Yolo3 | https://github.com/bubbliiiing/efficientnet-yolo3-pytorch
|
||||
YoloV4 | https://github.com/bubbliiiing/yolov4-pytorch
|
||||
YoloV4-tiny | https://github.com/bubbliiiing/yolov4-tiny-pytorch
|
||||
Mobilenet-Yolov4 | https://github.com/bubbliiiing/mobilenet-yolov4-pytorch
|
||||
YoloV5-V5.0 | https://github.com/bubbliiiing/yolov5-pytorch
|
||||
YoloV5-V6.1 | https://github.com/bubbliiiing/yolov5-v6.1-pytorch
|
||||
YoloX | https://github.com/bubbliiiing/yolox-pytorch
|
||||
YoloV7 | https://github.com/bubbliiiing/yolov7-pytorch
|
||||
YoloV7-tiny | https://github.com/bubbliiiing/yolov7-tiny-pytorch
|
||||
|
||||
## 性能情况
|
||||
| 训练数据集 | 权值文件名称 | 测试数据集 | 输入图片大小 | mAP 0.5:0.95 | mAP 0.5 |
|
||||
| :-----: | :-----: | :------: | :------: | :------: | :-----: |
|
||||
| COCO-Train2017 | [yolox_nano.pth](https://github.com/bubbliiiing/yolox-pytorch/releases/download/v1.0/yolox_nano.pth) | COCO-Val2017 | 640x640 | 27.4 | 44.5
|
||||
| COCO-Train2017 | [yolox_tiny.pth](https://github.com/bubbliiiing/yolox-pytorch/releases/download/v1.0/yolox_tiny.pth) | COCO-Val2017 | 640x640 | 34.7 | 53.6
|
||||
| COCO-Train2017 | [yolox_s.pth](https://github.com/bubbliiiing/yolox-pytorch/releases/download/v1.0/yolox_s.pth) | COCO-Val2017 | 640x640 | 38.2 | 57.7
|
||||
| COCO-Train2017 | [yolox_m.pth](https://github.com/bubbliiiing/yolox-pytorch/releases/download/v1.0/yolox_m.pth) | COCO-Val2017 | 640x640 | 44.8 | 63.9
|
||||
| COCO-Train2017 | [yolox_l.pth](https://github.com/bubbliiiing/yolox-pytorch/releases/download/v1.0/yolox_l.pth) | COCO-Val2017 | 640x640 | 47.9 | 66.6
|
||||
| COCO-Train2017 | [yolox_x.pth](https://github.com/bubbliiiing/yolox-pytorch/releases/download/v1.0/yolox_x.pth) | COCO-Val2017 | 640x640 | 49.0 | 67.7
|
||||
|
||||
## 实现的内容
|
||||
- [x] 主干特征提取网络:使用了Focus网络结构。
|
||||
- [x] 分类回归层:Decoupled Head,在YoloX中,Yolo Head被分为了分类回归两部分,最后预测的时候才整合在一起。
|
||||
- [x] 训练用到的小技巧:Mosaic数据增强、IOU和GIOU、学习率余弦退火衰减。
|
||||
- [x] Anchor Free:不使用先验框
|
||||
- [x] SimOTA:为不同大小的目标动态匹配正样本。
|
||||
|
||||
## 所需环境
|
||||
pytorch==1.2.0
|
||||
|
||||
## 文件下载
|
||||
训练所需的权值可在百度网盘中下载。
|
||||
链接: https://pan.baidu.com/s/1bi2UBwwIHES0OpLeyYuBFg
|
||||
提取码: f4ni
|
||||
|
||||
VOC数据集下载地址如下,里面已经包括了训练集、测试集、验证集(与测试集一样),无需再次划分:
|
||||
链接: https://pan.baidu.com/s/19Mw2u_df_nBzsC2lg20fQA
|
||||
提取码: j5ge
|
||||
|
||||
## 训练步骤
|
||||
### a、训练VOC07+12数据集
|
||||
1. 数据集的准备
|
||||
**本文使用VOC格式进行训练,训练前需要下载好VOC07+12的数据集,解压后放在根目录**
|
||||
|
||||
2. 数据集的处理
|
||||
修改voc_annotation.py里面的annotation_mode=2,运行voc_annotation.py生成根目录下的2007_train.txt和2007_val.txt。
|
||||
|
||||
3. 开始网络训练
|
||||
train.py的默认参数用于训练VOC数据集,直接运行train.py即可开始训练。
|
||||
|
||||
4. 训练结果预测
|
||||
训练结果预测需要用到两个文件,分别是yolo.py和predict.py。我们首先需要去yolo.py里面修改model_path以及classes_path,这两个参数必须要修改。
|
||||
**model_path指向训练好的权值文件,在logs文件夹里。
|
||||
classes_path指向检测类别所对应的txt。**
|
||||
完成修改后就可以运行predict.py进行检测了。运行后输入图片路径即可检测。
|
||||
|
||||
### b、训练自己的数据集
|
||||
1. 数据集的准备
|
||||
**本文使用VOC格式进行训练,训练前需要自己制作好数据集,**
|
||||
训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的Annotation中。
|
||||
训练前将图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中。
|
||||
|
||||
2. 数据集的处理
|
||||
在完成数据集的摆放之后,我们需要利用voc_annotation.py获得训练用的2007_train.txt和2007_val.txt。
|
||||
修改voc_annotation.py里面的参数。第一次训练可以仅修改classes_path,classes_path用于指向检测类别所对应的txt。
|
||||
训练自己的数据集时,可以自己建立一个cls_classes.txt,里面写自己所需要区分的类别。
|
||||
model_data/cls_classes.txt文件内容为:
|
||||
```python
|
||||
cat
|
||||
dog
|
||||
...
|
||||
```
|
||||
修改voc_annotation.py中的classes_path,使其对应cls_classes.txt,并运行voc_annotation.py。
|
||||
|
||||
3. 开始网络训练
|
||||
**训练的参数较多,均在train.py中,大家可以在下载库后仔细看注释,其中最重要的部分依然是train.py里的classes_path。**
|
||||
**classes_path用于指向检测类别所对应的txt,这个txt和voc_annotation.py里面的txt一样!训练自己的数据集必须要修改!**
|
||||
修改完classes_path后就可以运行train.py开始训练了,在训练多个epoch后,权值会生成在logs文件夹中。
|
||||
|
||||
4. 训练结果预测
|
||||
训练结果预测需要用到两个文件,分别是yolo.py和predict.py。在yolo.py里面修改model_path以及classes_path。
|
||||
**model_path指向训练好的权值文件,在logs文件夹里。
|
||||
classes_path指向检测类别所对应的txt。**
|
||||
完成修改后就可以运行predict.py进行检测了。运行后输入图片路径即可检测。
|
||||
|
||||
## 预测步骤
|
||||
### a、使用预训练权重
|
||||
1. 下载完库后解压,在百度网盘下载yolo_weights.pth,放入model_data,运行predict.py,输入
|
||||
```python
|
||||
img/street.jpg
|
||||
```
|
||||
2. 在predict.py里面进行设置可以进行fps测试和video视频检测。
|
||||
### b、使用自己训练的权重
|
||||
1. 按照训练步骤训练。
|
||||
2. 在yolo.py文件里面,在如下部分修改model_path和classes_path使其对应训练好的文件;**model_path对应logs文件夹下面的权值文件,classes_path是model_path对应分的类**。
|
||||
```python
|
||||
_defaults = {
|
||||
#--------------------------------------------------------------------------#
|
||||
# 使用自己训练好的模型进行预测一定要修改model_path和classes_path!
|
||||
# model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
|
||||
# 如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改
|
||||
#--------------------------------------------------------------------------#
|
||||
"model_path" : 'model_data/yolox_s.pth',
|
||||
"classes_path" : 'model_data/coco_classes.txt',
|
||||
#---------------------------------------------------------------------#
|
||||
# 输入图片的大小,必须为32的倍数。
|
||||
#---------------------------------------------------------------------#
|
||||
"input_shape" : [640, 640],
|
||||
#---------------------------------------------------------------------#
|
||||
# 所使用的YoloX的版本。nano、tiny、s、m、l、x
|
||||
#---------------------------------------------------------------------#
|
||||
"phi" : 's',
|
||||
#---------------------------------------------------------------------#
|
||||
# 只有得分大于置信度的预测框会被保留下来
|
||||
#---------------------------------------------------------------------#
|
||||
"confidence" : 0.5,
|
||||
#---------------------------------------------------------------------#
|
||||
# 非极大抑制所用到的nms_iou大小
|
||||
#---------------------------------------------------------------------#
|
||||
"nms_iou" : 0.3,
|
||||
#---------------------------------------------------------------------#
|
||||
# 该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize,
|
||||
# 在多次测试后,发现关闭letterbox_image直接resize的效果更好
|
||||
#---------------------------------------------------------------------#
|
||||
"letterbox_image" : True,
|
||||
#-------------------------------#
|
||||
# 是否使用Cuda
|
||||
# 没有GPU可以设置成False
|
||||
#-------------------------------#
|
||||
"cuda" : True,
|
||||
}
|
||||
```
|
||||
3. 运行predict.py,输入
|
||||
```python
|
||||
img/street.jpg
|
||||
```
|
||||
4. 在predict.py里面进行设置可以进行fps测试和video视频检测。
|
||||
|
||||
## 评估步骤
|
||||
### a、评估VOC07+12的测试集
|
||||
1. 本文使用VOC格式进行评估。VOC07+12已经划分好了测试集,无需利用voc_annotation.py生成ImageSets文件夹下的txt。
|
||||
2. 在yolo.py里面修改model_path以及classes_path。**model_path指向训练好的权值文件,在logs文件夹里。classes_path指向检测类别所对应的txt。**
|
||||
3. 运行get_map.py即可获得评估结果,评估结果会保存在map_out文件夹中。
|
||||
|
||||
### b、评估自己的数据集
|
||||
1. 本文使用VOC格式进行评估。
|
||||
2. 如果在训练前已经运行过voc_annotation.py文件,代码会自动将数据集划分成训练集、验证集和测试集。如果想要修改测试集的比例,可以修改voc_annotation.py文件下的trainval_percent。trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1。train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1。
|
||||
3. 利用voc_annotation.py划分测试集后,前往get_map.py文件修改classes_path,classes_path用于指向检测类别所对应的txt,这个txt和训练时的txt一样。评估自己的数据集必须要修改。
|
||||
4. 在yolo.py里面修改model_path以及classes_path。**model_path指向训练好的权值文件,在logs文件夹里。classes_path指向检测类别所对应的txt。**
|
||||
5. 运行get_map.py即可获得评估结果,评估结果会保存在map_out文件夹中。
|
||||
|
||||
## Reference
|
||||
https://github.com/Megvii-BaseDetection/YOLOX
|
||||
@@ -0,0 +1,17 @@
|
||||
<!--
|
||||
* @Author: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
|
||||
* @Date: 2023-03-31 20:04:39
|
||||
* @LastEditors: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
|
||||
* @LastEditTime: 2023-03-31 20:04:47
|
||||
* @FilePath: /luyuetong-data/yolox-pytorch/display_image.html
|
||||
* @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
|
||||
-->
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Display Image</title>
|
||||
</head>
|
||||
<body>
|
||||
<img src="{{ url_for('upload_image') }}" alt="Uploaded Image">
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,95 @@
|
||||
'''
|
||||
Author: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
|
||||
Date: 2023-03-29 20:22:34
|
||||
LastEditors: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
|
||||
LastEditTime: 2023-04-05 15:46:07
|
||||
FilePath: /luyuetong-data/yolox-pytorch/flask_on.py
|
||||
Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://githaub.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
|
||||
'''
|
||||
from flask import Flask, make_response, request, jsonify,send_file
|
||||
from flask_cors import CORS
|
||||
from werkzeug.utils import secure_filename
|
||||
from PIL import Image
|
||||
from yolo import YOLO
|
||||
from predict import PREDICT
|
||||
import os
|
||||
import quanju
|
||||
import requests
|
||||
import urllib.request
|
||||
from io import BytesIO
|
||||
import base64
|
||||
import io
|
||||
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
@app.route('/uploads', methods=['GET','POST'])
|
||||
def uploads():
|
||||
global filename
|
||||
# 检查是否收到了POST请求
|
||||
if request.method == 'POST':
|
||||
|
||||
|
||||
# 检查是否收到了文件
|
||||
if 'file' not in request.files:
|
||||
return jsonify({'error': 'No file uploaded.'}), 400
|
||||
|
||||
file = request.files['file']
|
||||
filename = file.read() #此时转换出来的是二进制内容
|
||||
file_path = filename.decode("utf-8")
|
||||
|
||||
print(file_path)
|
||||
print(f"file:{file}")
|
||||
|
||||
# 将blob转换为base64编码的字符串
|
||||
image_data = base64.b64encode(file.read()).decode('utf-8')
|
||||
|
||||
# 将base64编码的字符串解码为图片数据
|
||||
img = Image.open(io.BytesIO(base64.b64decode(image_data)))
|
||||
Image.open(img)
|
||||
# 将bytes类型传递给预测函数进行处理
|
||||
result = PREDICT(img)
|
||||
|
||||
|
||||
#
|
||||
# print(file)
|
||||
# filename = secure_filename(file.filename)
|
||||
# file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
|
||||
# result = PREDICT()
|
||||
|
||||
#检查是否选择了文件
|
||||
if file.filename == '':
|
||||
return jsonify({'error': 'No file selected.'}), 400
|
||||
|
||||
|
||||
# # 检查文件类型是否合法
|
||||
# # if file and allowed_file(file.filename):
|
||||
# filename = secure_filename(file.filename)
|
||||
#保存文件到本地
|
||||
# file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
|
||||
|
||||
#自定义一个全局变量承接名字(自定义库.变量)
|
||||
# quanju.mingzi = filename
|
||||
return "ok"
|
||||
|
||||
# 将文件路径传递给模型进行预测
|
||||
# result = PREDICT(os.path.join(app.config['UPLOAD_FOLDER'], filename))
|
||||
|
||||
# return send_file(f'./results/result_{quanju.mingzi}', mimetype='image/jpeg')
|
||||
# return send_file('results/result_{}'.format(quanju.mingzi), mimetype='image/jpeg')
|
||||
|
||||
# return jsonify({'error': 'Invalid file type.'}), 400
|
||||
|
||||
|
||||
|
||||
def allowed_file(filename):
|
||||
# 检查文件类型是否合法,这里仅接受jpeg、jpg、png类型的图片文件
|
||||
return '.' in filename and filename.rsplit('.', 1)[1].lower() in {'jpeg', 'jpg', 'png'}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# app.config['UPLOAD_FOLDER'] = './uploads'
|
||||
app.config['UPLOAD_FOLDER'] = 'yolox-pytorch/uploads'
|
||||
# app.debug = False
|
||||
app.run(host='10.21.19.104', port=8000)
|
||||
@@ -0,0 +1,42 @@
|
||||
'''
|
||||
Author: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
|
||||
Date: 2023-03-29 20:22:34
|
||||
LastEditors: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
|
||||
LastEditTime: 2023-04-05 15:27:13
|
||||
FilePath: /luyuetong-data/yolox-pytorch/flask_on.py
|
||||
Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://githaub.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
|
||||
'''
|
||||
import os
|
||||
|
||||
from flask import Flask, request, jsonify
|
||||
from gevent import pywsgi
|
||||
|
||||
from predict import PREDICT
|
||||
|
||||
app = Flask(__name__)
|
||||
# 设置未处理文件路径以及处理保存路径
|
||||
# dataPath = '/home/DiatomRecognition/upload'
|
||||
# resultPath = '/home/DiatomRecognition/result'
|
||||
# localhost 测试
|
||||
dataPath = 'E:/Document/project/temp/data'
|
||||
resultPath = r'E:\Document\project\temp/save'
|
||||
|
||||
|
||||
@app.route('/predict', methods=['POST'])
|
||||
def predict_one():
|
||||
if not os.path.exists(resultPath):
|
||||
# 如果不存在,创建文件夹及其所有父文件夹
|
||||
os.makedirs(resultPath)
|
||||
# 从请求参数中提取文件名
|
||||
fileName = request.form['fileName']
|
||||
# 将文件路径传递给模型进行预测
|
||||
result = PREDICT(dataPath, resultPath, fileName) # predicted_class,score如果需要再处理
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# app.debug = False
|
||||
server = pywsgi.WSGIServer(('0.0.0.0', 2000), app)
|
||||
app.config["JSON_AS_ASCII"] = False
|
||||
server.serve_forever()
|
||||
@@ -0,0 +1,138 @@
|
||||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from utils.utils import get_classes
|
||||
from utils.utils_map import get_coco_map, get_map
|
||||
from yolo import YOLO
|
||||
|
||||
if __name__ == "__main__":
|
||||
'''
|
||||
Recall和Precision不像AP是一个面积的概念,因此在门限值(Confidence)不同时,网络的Recall和Precision值是不同的。
|
||||
默认情况下,本代码计算的Recall和Precision代表的是当门限值(Confidence)为0.5时,所对应的Recall和Precision值。
|
||||
|
||||
受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算不同门限条件下的Recall和Precision值
|
||||
因此,本代码获得的map_out/detection-results/里面的txt的框的数量一般会比直接predict多一些,目的是列出所有可能的预测框,
|
||||
'''
|
||||
#------------------------------------------------------------------------------------------------------------------#
|
||||
# map_mode用于指定该文件运行时计算的内容
|
||||
# map_mode为0代表整个map计算流程,包括获得预测结果、获得真实框、计算VOC_map。
|
||||
# map_mode为1代表仅仅获得预测结果。
|
||||
# map_mode为2代表仅仅获得真实框。
|
||||
# map_mode为3代表仅仅计算VOC_map。
|
||||
# map_mode为4代表利用COCO工具箱计算当前数据集的0.50:0.95map。需要获得预测结果、获得真实框后并安装pycocotools才行
|
||||
#-------------------------------------------------------------------------------------------------------------------#
|
||||
map_mode = 0
|
||||
#--------------------------------------------------------------------------------------#
|
||||
# 此处的classes_path用于指定需要测量VOC_map的类别
|
||||
# 一般情况下与训练和预测所用的classes_path一致即可
|
||||
#--------------------------------------------------------------------------------------#
|
||||
classes_path = 'model_data/voc_classes.txt' #..
|
||||
#--------------------------------------------------------------------------------------#
|
||||
# MINOVERLAP用于指定想要获得的mAP0.x,mAP0.x的意义是什么请同学们百度一下。
|
||||
# 比如计算mAP0.75,可以设定MINOVERLAP = 0.75。
|
||||
#
|
||||
# 当某一预测框与真实框重合度大于MINOVERLAP时,该预测框被认为是正样本,否则为负样本。
|
||||
# 因此MINOVERLAP的值越大,预测框要预测的越准确才能被认为是正样本,此时算出来的mAP值越低,
|
||||
#--------------------------------------------------------------------------------------#
|
||||
MINOVERLAP = 0.5
|
||||
#--------------------------------------------------------------------------------------#
|
||||
# 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算mAP
|
||||
# 因此,confidence的值应当设置的尽量小进而获得全部可能的预测框。
|
||||
#
|
||||
# 该值一般不调整。因为计算mAP需要获得近乎所有的预测框,此处的confidence不能随便更改。
|
||||
# 想要获得不同门限值下的Recall和Precision值,请修改下方的score_threhold。
|
||||
#--------------------------------------------------------------------------------------#
|
||||
confidence = 0.001
|
||||
#--------------------------------------------------------------------------------------#
|
||||
# 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。
|
||||
#
|
||||
# 该值一般不调整。
|
||||
#--------------------------------------------------------------------------------------#
|
||||
nms_iou = 0.5
|
||||
#---------------------------------------------------------------------------------------------------------------#
|
||||
# Recall和Precision不像AP是一个面积的概念,因此在门限值不同时,网络的Recall和Precision值是不同的。
|
||||
#
|
||||
# 默认情况下,本代码计算的Recall和Precision代表的是当门限值为0.5(此处定义为score_threhold)时所对应的Recall和Precision值。
|
||||
# 因为计算mAP需要获得近乎所有的预测框,上面定义的confidence不能随便更改。
|
||||
# 这里专门定义一个score_threhold用于代表门限值,进而在计算mAP时找到门限值对应的Recall和Precision值。
|
||||
#---------------------------------------------------------------------------------------------------------------#
|
||||
score_threhold = 0.5
|
||||
#-------------------------------------------------------#
|
||||
# map_vis用于指定是否开启VOC_map计算的可视化
|
||||
#-------------------------------------------------------#
|
||||
map_vis = False
|
||||
#-------------------------------------------------------#
|
||||
# 指向VOC数据集所在的文件夹
|
||||
# 默认指向根目录下的VOC数据集
|
||||
#-------------------------------------------------------#
|
||||
VOCdevkit_path = 'VOCdevkit'
|
||||
#-------------------------------------------------------#
|
||||
# 结果输出的文件夹,默认为map_out
|
||||
#-------------------------------------------------------#
|
||||
map_out_path = 'map_out'
|
||||
|
||||
image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split()
|
||||
|
||||
if not os.path.exists(map_out_path):
|
||||
os.makedirs(map_out_path)
|
||||
if not os.path.exists(os.path.join(map_out_path, 'ground-truth')):
|
||||
os.makedirs(os.path.join(map_out_path, 'ground-truth'))
|
||||
if not os.path.exists(os.path.join(map_out_path, 'detection-results')):
|
||||
os.makedirs(os.path.join(map_out_path, 'detection-results'))
|
||||
if not os.path.exists(os.path.join(map_out_path, 'images-optional')):
|
||||
os.makedirs(os.path.join(map_out_path, 'images-optional'))
|
||||
|
||||
class_names, _ = get_classes(classes_path)
|
||||
|
||||
if map_mode == 0 or map_mode == 1:
|
||||
print("Load model.")
|
||||
yolo = YOLO(confidence = confidence, nms_iou = nms_iou)
|
||||
print("Load model done.")
|
||||
|
||||
print("Get predict result.")
|
||||
for image_id in tqdm(image_ids):
|
||||
image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg")
|
||||
image = Image.open(image_path)
|
||||
if map_vis:
|
||||
image.save(os.path.join(map_out_path, "images-optional/" + image_id + ".jpg"))
|
||||
yolo.get_map_txt(image_id, image, class_names, map_out_path)
|
||||
print("Get predict result done.")
|
||||
|
||||
if map_mode == 0 or map_mode == 2:
|
||||
print("Get ground truth result.")
|
||||
for image_id in tqdm(image_ids):
|
||||
with open(os.path.join(map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
|
||||
root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot()
|
||||
for obj in root.findall('object'):
|
||||
difficult_flag = False
|
||||
if obj.find('difficult')!=None:
|
||||
difficult = obj.find('difficult').text
|
||||
if int(difficult)==1:
|
||||
difficult_flag = True
|
||||
obj_name = obj.find('name').text
|
||||
if obj_name not in class_names:
|
||||
continue
|
||||
bndbox = obj.find('bndbox')
|
||||
left = bndbox.find('xmin').text
|
||||
top = bndbox.find('ymin').text
|
||||
right = bndbox.find('xmax').text
|
||||
bottom = bndbox.find('ymax').text
|
||||
|
||||
if difficult_flag:
|
||||
new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
|
||||
else:
|
||||
new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
|
||||
print("Get ground truth result done.")
|
||||
|
||||
if map_mode == 0 or map_mode == 3:
|
||||
print("Get map.")
|
||||
get_map(MINOVERLAP, True, score_threhold = score_threhold, path = map_out_path)
|
||||
print("Get map done.")
|
||||
|
||||
if map_mode == 4:
|
||||
print("Get map.")
|
||||
get_coco_map(class_names = class_names, path = map_out_path)
|
||||
print("Get map done.")
|
||||
|
After Width: | Height: | Size: 50 KiB |
|
After Width: | Height: | Size: 43 KiB |
|
After Width: | Height: | Size: 50 KiB |
|
After Width: | Height: | Size: 41 KiB |
|
After Width: | Height: | Size: 50 KiB |
|
After Width: | Height: | Size: 67 KiB |
|
After Width: | Height: | Size: 45 KiB |
|
After Width: | Height: | Size: 40 KiB |
|
After Width: | Height: | Size: 437 KiB |
|
After Width: | Height: | Size: 253 KiB |
|
After Width: | Height: | Size: 222 KiB |
|
After Width: | Height: | Size: 255 KiB |
|
After Width: | Height: | Size: 230 KiB |
|
After Width: | Height: | Size: 253 KiB |
|
After Width: | Height: | Size: 305 KiB |
|
After Width: | Height: | Size: 237 KiB |
|
After Width: | Height: | Size: 230 KiB |
|
After Width: | Height: | Size: 2.1 MiB |
@@ -0,0 +1 @@
|
||||
#
|
||||
@@ -0,0 +1,231 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# Copyright (c) Megvii, Inc. and its affiliates.
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
class SiLU(nn.Module):
|
||||
@staticmethod
|
||||
def forward(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
def get_activation(name="silu", inplace=True):
|
||||
if name == "silu":
|
||||
module = SiLU()
|
||||
elif name == "relu":
|
||||
module = nn.ReLU(inplace=inplace)
|
||||
elif name == "lrelu":
|
||||
module = nn.LeakyReLU(0.1, inplace=inplace)
|
||||
else:
|
||||
raise AttributeError("Unsupported act type: {}".format(name))
|
||||
return module
|
||||
|
||||
class Focus(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
|
||||
super().__init__()
|
||||
self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)
|
||||
|
||||
def forward(self, x):
|
||||
patch_top_left = x[..., ::2, ::2]
|
||||
patch_bot_left = x[..., 1::2, ::2]
|
||||
patch_top_right = x[..., ::2, 1::2]
|
||||
patch_bot_right = x[..., 1::2, 1::2]
|
||||
x = torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right,), dim=1,)
|
||||
return self.conv(x)
|
||||
|
||||
class BaseConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):
|
||||
super().__init__()
|
||||
pad = (ksize - 1) // 2
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, groups=groups, bias=bias)
|
||||
self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03)
|
||||
self.act = get_activation(act, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.act(self.bn(self.conv(x)))
|
||||
|
||||
def fuseforward(self, x):
|
||||
return self.act(self.conv(x))
|
||||
|
||||
class DWConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
|
||||
super().__init__()
|
||||
self.dconv = BaseConv(in_channels, in_channels, ksize=ksize, stride=stride, groups=in_channels, act=act,)
|
||||
self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dconv(x)
|
||||
return self.pconv(x)
|
||||
|
||||
class SPPBottleneck(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"):
|
||||
super().__init__()
|
||||
hidden_channels = in_channels // 2
|
||||
self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
|
||||
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes])
|
||||
conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
|
||||
self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = torch.cat([x] + [m(x) for m in self.m], dim=1)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
#--------------------------------------------------#
|
||||
# 残差结构的构建,小的残差结构
|
||||
#--------------------------------------------------#
|
||||
class Bottleneck(nn.Module):
|
||||
# Standard bottleneck
|
||||
def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
|
||||
super().__init__()
|
||||
hidden_channels = int(out_channels * expansion)
|
||||
Conv = DWConv if depthwise else BaseConv
|
||||
#--------------------------------------------------#
|
||||
# 利用1x1卷积进行通道数的缩减。缩减率一般是50%
|
||||
#--------------------------------------------------#
|
||||
self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
|
||||
#--------------------------------------------------#
|
||||
# 利用3x3卷积进行通道数的拓张。并且完成特征提取
|
||||
#--------------------------------------------------#
|
||||
self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
|
||||
self.use_add = shortcut and in_channels == out_channels
|
||||
|
||||
def forward(self, x):
|
||||
y = self.conv2(self.conv1(x))
|
||||
if self.use_add:
|
||||
y = y + x
|
||||
return y
|
||||
|
||||
class CSPLayer(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, n=1, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
|
||||
# ch_in, ch_out, number, shortcut, groups, expansion
|
||||
super().__init__()
|
||||
hidden_channels = int(out_channels * expansion)
|
||||
#--------------------------------------------------#
|
||||
# 主干部分的初次卷积
|
||||
#--------------------------------------------------#
|
||||
self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
|
||||
#--------------------------------------------------#
|
||||
# 大的残差边部分的初次卷积
|
||||
#--------------------------------------------------#
|
||||
self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
|
||||
#-----------------------------------------------#
|
||||
# 对堆叠的结果进行卷积的处理
|
||||
#-----------------------------------------------#
|
||||
self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
|
||||
|
||||
#--------------------------------------------------#
|
||||
# 根据循环的次数构建上述Bottleneck残差结构
|
||||
#--------------------------------------------------#
|
||||
module_list = [Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)]
|
||||
self.m = nn.Sequential(*module_list)
|
||||
|
||||
def forward(self, x):
|
||||
#-------------------------------#
|
||||
# x_1是主干部分
|
||||
#-------------------------------#
|
||||
x_1 = self.conv1(x)
|
||||
#-------------------------------#
|
||||
# x_2是大的残差边部分
|
||||
#-------------------------------#
|
||||
x_2 = self.conv2(x)
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 主干部分利用残差结构堆叠继续进行特征提取
|
||||
#-----------------------------------------------#
|
||||
x_1 = self.m(x_1)
|
||||
#-----------------------------------------------#
|
||||
# 主干部分和大的残差边部分进行堆叠
|
||||
#-----------------------------------------------#
|
||||
x = torch.cat((x_1, x_2), dim=1)
|
||||
#-----------------------------------------------#
|
||||
# 对堆叠的结果进行卷积的处理
|
||||
#-----------------------------------------------#
|
||||
return self.conv3(x)
|
||||
|
||||
class CSPDarknet(nn.Module):
|
||||
def __init__(self, dep_mul, wid_mul, out_features=("dark3", "dark4", "dark5"), depthwise=False, act="silu",):
|
||||
super().__init__()
|
||||
assert out_features, "please provide output features of Darknet"
|
||||
self.out_features = out_features
|
||||
Conv = DWConv if depthwise else BaseConv
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 输入图片是640, 640, 3
|
||||
# 初始的基本通道是64
|
||||
#-----------------------------------------------#
|
||||
base_channels = int(wid_mul * 64) # 64
|
||||
base_depth = max(round(dep_mul * 3), 1) # 3
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 利用focus网络结构进行特征提取
|
||||
# 640, 640, 3 -> 320, 320, 12 -> 320, 320, 64
|
||||
#-----------------------------------------------#
|
||||
self.stem = Focus(3, base_channels, ksize=3, act=act)
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 完成卷积之后,320, 320, 64 -> 160, 160, 128
|
||||
# 完成CSPlayer之后,160, 160, 128 -> 160, 160, 128
|
||||
#-----------------------------------------------#
|
||||
self.dark2 = nn.Sequential(
|
||||
Conv(base_channels, base_channels * 2, 3, 2, act=act),
|
||||
CSPLayer(base_channels * 2, base_channels * 2, n=base_depth, depthwise=depthwise, act=act),
|
||||
)
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 完成卷积之后,160, 160, 128 -> 80, 80, 256
|
||||
# 完成CSPlayer之后,80, 80, 256 -> 80, 80, 256
|
||||
#-----------------------------------------------#
|
||||
self.dark3 = nn.Sequential(
|
||||
Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
|
||||
CSPLayer(base_channels * 4, base_channels * 4, n=base_depth * 3, depthwise=depthwise, act=act),
|
||||
)
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 完成卷积之后,80, 80, 256 -> 40, 40, 512
|
||||
# 完成CSPlayer之后,40, 40, 512 -> 40, 40, 512
|
||||
#-----------------------------------------------#
|
||||
self.dark4 = nn.Sequential(
|
||||
Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
|
||||
CSPLayer(base_channels * 8, base_channels * 8, n=base_depth * 3, depthwise=depthwise, act=act),
|
||||
)
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 完成卷积之后,40, 40, 512 -> 20, 20, 1024
|
||||
# 完成SPP之后,20, 20, 1024 -> 20, 20, 1024
|
||||
# 完成CSPlayer之后,20, 20, 1024 -> 20, 20, 1024
|
||||
#-----------------------------------------------#
|
||||
self.dark5 = nn.Sequential(
|
||||
Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
|
||||
SPPBottleneck(base_channels * 16, base_channels * 16, activation=act),
|
||||
CSPLayer(base_channels * 16, base_channels * 16, n=base_depth, shortcut=False, depthwise=depthwise, act=act),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
outputs = {}
|
||||
x = self.stem(x)
|
||||
outputs["stem"] = x
|
||||
x = self.dark2(x)
|
||||
outputs["dark2"] = x
|
||||
#-----------------------------------------------#
|
||||
# dark3的输出为80, 80, 256,是一个有效特征层
|
||||
#-----------------------------------------------#
|
||||
x = self.dark3(x)
|
||||
outputs["dark3"] = x
|
||||
#-----------------------------------------------#
|
||||
# dark4的输出为40, 40, 512,是一个有效特征层
|
||||
#-----------------------------------------------#
|
||||
x = self.dark4(x)
|
||||
outputs["dark4"] = x
|
||||
#-----------------------------------------------#
|
||||
# dark5的输出为20, 20, 1024,是一个有效特征层
|
||||
#-----------------------------------------------#
|
||||
x = self.dark5(x)
|
||||
outputs["dark5"] = x
|
||||
return {k: v for k, v in outputs.items() if k in self.out_features}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(CSPDarknet(1, 1))
|
||||
@@ -0,0 +1,247 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# Copyright (c) Megvii, Inc. and its affiliates.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .darknet import BaseConv, CSPDarknet, CSPLayer, DWConv
|
||||
|
||||
|
||||
class YOLOXHead(nn.Module):
|
||||
def __init__(self, num_classes, width = 1.0, in_channels = [256, 512, 1024], act = "silu", depthwise = False,):
|
||||
super().__init__()
|
||||
Conv = DWConv if depthwise else BaseConv
|
||||
|
||||
self.cls_convs = nn.ModuleList()
|
||||
self.reg_convs = nn.ModuleList()
|
||||
self.cls_preds = nn.ModuleList()
|
||||
self.reg_preds = nn.ModuleList()
|
||||
self.obj_preds = nn.ModuleList()
|
||||
self.stems = nn.ModuleList()
|
||||
|
||||
for i in range(len(in_channels)):
|
||||
self.stems.append(BaseConv(in_channels = int(in_channels[i] * width), out_channels = int(256 * width), ksize = 1, stride = 1, act = act))
|
||||
self.cls_convs.append(nn.Sequential(*[
|
||||
Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act),
|
||||
Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act),
|
||||
]))
|
||||
self.cls_preds.append(
|
||||
nn.Conv2d(in_channels = int(256 * width), out_channels = num_classes, kernel_size = 1, stride = 1, padding = 0)
|
||||
)
|
||||
|
||||
|
||||
self.reg_convs.append(nn.Sequential(*[
|
||||
Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act),
|
||||
Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act)
|
||||
]))
|
||||
self.reg_preds.append(
|
||||
nn.Conv2d(in_channels = int(256 * width), out_channels = 4, kernel_size = 1, stride = 1, padding = 0)
|
||||
)
|
||||
self.obj_preds.append(
|
||||
nn.Conv2d(in_channels = int(256 * width), out_channels = 1, kernel_size = 1, stride = 1, padding = 0)
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
#---------------------------------------------------#
|
||||
# inputs输入
|
||||
# P3_out 80, 80, 256
|
||||
# P4_out 40, 40, 512
|
||||
# P5_out 20, 20, 1024
|
||||
#---------------------------------------------------#
|
||||
outputs = []
|
||||
for k, x in enumerate(inputs):
|
||||
#---------------------------------------------------#
|
||||
# 利用1x1卷积进行通道整合
|
||||
#---------------------------------------------------#
|
||||
x = self.stems[k](x)
|
||||
#---------------------------------------------------#
|
||||
# 利用两个卷积标准化激活函数来进行特征提取
|
||||
#---------------------------------------------------#
|
||||
cls_feat = self.cls_convs[k](x)
|
||||
#---------------------------------------------------#
|
||||
# 判断特征点所属的种类
|
||||
# 80, 80, num_classes
|
||||
# 40, 40, num_classes
|
||||
# 20, 20, num_classes
|
||||
#---------------------------------------------------#
|
||||
cls_output = self.cls_preds[k](cls_feat)
|
||||
|
||||
#---------------------------------------------------#
|
||||
# 利用两个卷积标准化激活函数来进行特征提取
|
||||
#---------------------------------------------------#
|
||||
reg_feat = self.reg_convs[k](x)
|
||||
#---------------------------------------------------#
|
||||
# 特征点的回归系数
|
||||
# reg_pred 80, 80, 4
|
||||
# reg_pred 40, 40, 4
|
||||
# reg_pred 20, 20, 4
|
||||
#---------------------------------------------------#
|
||||
reg_output = self.reg_preds[k](reg_feat)
|
||||
#---------------------------------------------------#
|
||||
# 判断特征点是否有对应的物体
|
||||
# obj_pred 80, 80, 1
|
||||
# obj_pred 40, 40, 1
|
||||
# obj_pred 20, 20, 1
|
||||
#---------------------------------------------------#
|
||||
obj_output = self.obj_preds[k](reg_feat)
|
||||
|
||||
output = torch.cat([reg_output, obj_output, cls_output], 1)
|
||||
outputs.append(output)
|
||||
return outputs
|
||||
|
||||
class YOLOPAFPN(nn.Module):
|
||||
def __init__(self, depth = 1.0, width = 1.0, in_features = ("dark3", "dark4", "dark5"), in_channels = [256, 512, 1024], depthwise = False, act = "silu"):
|
||||
super().__init__()
|
||||
Conv = DWConv if depthwise else BaseConv
|
||||
self.backbone = CSPDarknet(depth, width, depthwise = depthwise, act = act)
|
||||
self.in_features = in_features
|
||||
|
||||
self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
|
||||
|
||||
#-------------------------------------------#
|
||||
# 20, 20, 1024 -> 20, 20, 512
|
||||
#-------------------------------------------#
|
||||
self.lateral_conv0 = BaseConv(int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act)
|
||||
|
||||
#-------------------------------------------#
|
||||
# 40, 40, 1024 -> 40, 40, 512
|
||||
#-------------------------------------------#
|
||||
self.C3_p4 = CSPLayer(
|
||||
int(2 * in_channels[1] * width),
|
||||
int(in_channels[1] * width),
|
||||
round(3 * depth),
|
||||
False,
|
||||
depthwise = depthwise,
|
||||
act = act,
|
||||
)
|
||||
|
||||
#-------------------------------------------#
|
||||
# 40, 40, 512 -> 40, 40, 256
|
||||
#-------------------------------------------#
|
||||
self.reduce_conv1 = BaseConv(int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act)
|
||||
#-------------------------------------------#
|
||||
# 80, 80, 512 -> 80, 80, 256
|
||||
#-------------------------------------------#
|
||||
self.C3_p3 = CSPLayer(
|
||||
int(2 * in_channels[0] * width),
|
||||
int(in_channels[0] * width),
|
||||
round(3 * depth),
|
||||
False,
|
||||
depthwise = depthwise,
|
||||
act = act,
|
||||
)
|
||||
|
||||
#-------------------------------------------#
|
||||
# 80, 80, 256 -> 40, 40, 256
|
||||
#-------------------------------------------#
|
||||
self.bu_conv2 = Conv(int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act)
|
||||
#-------------------------------------------#
|
||||
# 40, 40, 256 -> 40, 40, 512
|
||||
#-------------------------------------------#
|
||||
self.C3_n3 = CSPLayer(
|
||||
int(2 * in_channels[0] * width),
|
||||
int(in_channels[1] * width),
|
||||
round(3 * depth),
|
||||
False,
|
||||
depthwise = depthwise,
|
||||
act = act,
|
||||
)
|
||||
|
||||
#-------------------------------------------#
|
||||
# 40, 40, 512 -> 20, 20, 512
|
||||
#-------------------------------------------#
|
||||
self.bu_conv1 = Conv(int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act)
|
||||
#-------------------------------------------#
|
||||
# 20, 20, 1024 -> 20, 20, 1024
|
||||
#-------------------------------------------#
|
||||
self.C3_n4 = CSPLayer(
|
||||
int(2 * in_channels[1] * width),
|
||||
int(in_channels[2] * width),
|
||||
round(3 * depth),
|
||||
False,
|
||||
depthwise = depthwise,
|
||||
act = act,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
out_features = self.backbone.forward(input)
|
||||
[feat1, feat2, feat3] = [out_features[f] for f in self.in_features]
|
||||
|
||||
#-------------------------------------------#
|
||||
# 20, 20, 1024 -> 20, 20, 512
|
||||
#-------------------------------------------#
|
||||
P5 = self.lateral_conv0(feat3)
|
||||
#-------------------------------------------#
|
||||
# 20, 20, 512 -> 40, 40, 512
|
||||
#-------------------------------------------#
|
||||
P5_upsample = self.upsample(P5)
|
||||
#-------------------------------------------#
|
||||
# 40, 40, 512 + 40, 40, 512 -> 40, 40, 1024
|
||||
#-------------------------------------------#
|
||||
P5_upsample = torch.cat([P5_upsample, feat2], 1)
|
||||
#-------------------------------------------#
|
||||
# 40, 40, 1024 -> 40, 40, 512
|
||||
#-------------------------------------------#
|
||||
P5_upsample = self.C3_p4(P5_upsample)
|
||||
|
||||
#-------------------------------------------#
|
||||
# 40, 40, 512 -> 40, 40, 256
|
||||
#-------------------------------------------#
|
||||
P4 = self.reduce_conv1(P5_upsample)
|
||||
#-------------------------------------------#
|
||||
# 40, 40, 256 -> 80, 80, 256
|
||||
#-------------------------------------------#
|
||||
P4_upsample = self.upsample(P4)
|
||||
#-------------------------------------------#
|
||||
# 80, 80, 256 + 80, 80, 256 -> 80, 80, 512
|
||||
#-------------------------------------------#
|
||||
P4_upsample = torch.cat([P4_upsample, feat1], 1)
|
||||
#-------------------------------------------#
|
||||
# 80, 80, 512 -> 80, 80, 256
|
||||
#-------------------------------------------#
|
||||
P3_out = self.C3_p3(P4_upsample)
|
||||
|
||||
#-------------------------------------------#
|
||||
# 80, 80, 256 -> 40, 40, 256
|
||||
#-------------------------------------------#
|
||||
P3_downsample = self.bu_conv2(P3_out)
|
||||
#-------------------------------------------#
|
||||
# 40, 40, 256 + 40, 40, 256 -> 40, 40, 512
|
||||
#-------------------------------------------#
|
||||
P3_downsample = torch.cat([P3_downsample, P4], 1)
|
||||
#-------------------------------------------#
|
||||
# 40, 40, 256 -> 40, 40, 512
|
||||
#-------------------------------------------#
|
||||
P4_out = self.C3_n3(P3_downsample)
|
||||
|
||||
#-------------------------------------------#
|
||||
# 40, 40, 512 -> 20, 20, 512
|
||||
#-------------------------------------------#
|
||||
P4_downsample = self.bu_conv1(P4_out)
|
||||
#-------------------------------------------#
|
||||
# 20, 20, 512 + 20, 20, 512 -> 20, 20, 1024
|
||||
#-------------------------------------------#
|
||||
P4_downsample = torch.cat([P4_downsample, P5], 1)
|
||||
#-------------------------------------------#
|
||||
# 20, 20, 1024 -> 20, 20, 1024
|
||||
#-------------------------------------------#
|
||||
P5_out = self.C3_n4(P4_downsample)
|
||||
|
||||
return (P3_out, P4_out, P5_out)
|
||||
|
||||
class YoloBody(nn.Module):
|
||||
def __init__(self, num_classes, phi):
|
||||
super().__init__()
|
||||
depth_dict = {'nano': 0.33, 'tiny': 0.33, 's' : 0.33, 'm' : 0.67, 'l' : 1.00, 'x' : 1.33,}
|
||||
width_dict = {'nano': 0.25, 'tiny': 0.375, 's' : 0.50, 'm' : 0.75, 'l' : 1.00, 'x' : 1.25,}
|
||||
depth, width = depth_dict[phi], width_dict[phi]
|
||||
depthwise = True if phi == 'nano' else False
|
||||
|
||||
self.backbone = YOLOPAFPN(depth, width, depthwise=depthwise)
|
||||
self.head = YOLOXHead(num_classes, width, depthwise=depthwise)
|
||||
|
||||
def forward(self, x):
|
||||
fpn_outs = self.backbone.forward(x)
|
||||
outputs = self.head.forward(fpn_outs)
|
||||
return outputs
|
||||
@@ -0,0 +1,488 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# Copyright (c) Megvii, Inc. and its affiliates.
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class IOUloss(nn.Module):
|
||||
def __init__(self, reduction="none", loss_type="iou"):
|
||||
super(IOUloss, self).__init__()
|
||||
self.reduction = reduction
|
||||
self.loss_type = loss_type
|
||||
|
||||
def forward(self, pred, target):
|
||||
assert pred.shape[0] == target.shape[0]
|
||||
|
||||
pred = pred.view(-1, 4)
|
||||
target = target.view(-1, 4)
|
||||
tl = torch.max(
|
||||
(pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
|
||||
)
|
||||
br = torch.min(
|
||||
(pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
|
||||
)
|
||||
|
||||
area_p = torch.prod(pred[:, 2:], 1)
|
||||
area_g = torch.prod(target[:, 2:], 1)
|
||||
|
||||
en = (tl < br).type(tl.type()).prod(dim=1)
|
||||
area_i = torch.prod(br - tl, 1) * en
|
||||
area_u = area_p + area_g - area_i
|
||||
iou = (area_i) / (area_u + 1e-16)
|
||||
|
||||
if self.loss_type == "iou":
|
||||
loss = 1 - iou ** 2
|
||||
elif self.loss_type == "giou":
|
||||
c_tl = torch.min(
|
||||
(pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
|
||||
)
|
||||
c_br = torch.max(
|
||||
(pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
|
||||
)
|
||||
area_c = torch.prod(c_br - c_tl, 1)
|
||||
giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
|
||||
loss = 1 - giou.clamp(min=-1.0, max=1.0)
|
||||
|
||||
if self.reduction == "mean":
|
||||
loss = loss.mean()
|
||||
elif self.reduction == "sum":
|
||||
loss = loss.sum()
|
||||
|
||||
return loss
|
||||
|
||||
class YOLOLoss(nn.Module):
|
||||
def __init__(self, num_classes, fp16, strides=[8, 16, 32]):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.strides = strides
|
||||
|
||||
self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
|
||||
self.iou_loss = IOUloss(reduction="none")
|
||||
self.grids = [torch.zeros(1)] * len(strides)
|
||||
self.fp16 = fp16
|
||||
|
||||
def forward(self, inputs, labels=None):
|
||||
outputs = []
|
||||
x_shifts = []
|
||||
y_shifts = []
|
||||
expanded_strides = []
|
||||
|
||||
#-----------------------------------------------#
|
||||
# inputs [[batch_size, num_classes + 5, 20, 20]
|
||||
# [batch_size, num_classes + 5, 40, 40]
|
||||
# [batch_size, num_classes + 5, 80, 80]]
|
||||
# outputs [[batch_size, 400, num_classes + 5]
|
||||
# [batch_size, 1600, num_classes + 5]
|
||||
# [batch_size, 6400, num_classes + 5]]
|
||||
# x_shifts [[batch_size, 400]
|
||||
# [batch_size, 1600]
|
||||
# [batch_size, 6400]]
|
||||
#-----------------------------------------------#
|
||||
for k, (stride, output) in enumerate(zip(self.strides, inputs)):
|
||||
output, grid = self.get_output_and_grid(output, k, stride)
|
||||
x_shifts.append(grid[:, :, 0])
|
||||
y_shifts.append(grid[:, :, 1])
|
||||
expanded_strides.append(torch.ones_like(grid[:, :, 0]) * stride)
|
||||
outputs.append(output)
|
||||
|
||||
return self.get_losses(x_shifts, y_shifts, expanded_strides, labels, torch.cat(outputs, 1))
|
||||
|
||||
def get_output_and_grid(self, output, k, stride):
|
||||
grid = self.grids[k]
|
||||
hsize, wsize = output.shape[-2:]
|
||||
if grid.shape[2:4] != output.shape[2:4]:
|
||||
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
|
||||
grid = torch.stack((xv, yv), 2).view(1, hsize, wsize, 2).type(output.type())
|
||||
self.grids[k] = grid
|
||||
grid = grid.view(1, -1, 2)
|
||||
|
||||
output = output.flatten(start_dim=2).permute(0, 2, 1)
|
||||
output[..., :2] = (output[..., :2] + grid.type_as(output)) * stride
|
||||
output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
|
||||
return output, grid
|
||||
|
||||
def get_losses(self, x_shifts, y_shifts, expanded_strides, labels, outputs):
|
||||
#-----------------------------------------------#
|
||||
# [batch, n_anchors_all, 4]
|
||||
#-----------------------------------------------#
|
||||
bbox_preds = outputs[:, :, :4]
|
||||
#-----------------------------------------------#
|
||||
# [batch, n_anchors_all, 1]
|
||||
#-----------------------------------------------#
|
||||
obj_preds = outputs[:, :, 4:5]
|
||||
#-----------------------------------------------#
|
||||
# [batch, n_anchors_all, n_cls]
|
||||
#-----------------------------------------------#
|
||||
cls_preds = outputs[:, :, 5:]
|
||||
|
||||
total_num_anchors = outputs.shape[1]
|
||||
#-----------------------------------------------#
|
||||
# x_shifts [1, n_anchors_all]
|
||||
# y_shifts [1, n_anchors_all]
|
||||
# expanded_strides [1, n_anchors_all]
|
||||
#-----------------------------------------------#
|
||||
x_shifts = torch.cat(x_shifts, 1).type_as(outputs)
|
||||
y_shifts = torch.cat(y_shifts, 1).type_as(outputs)
|
||||
expanded_strides = torch.cat(expanded_strides, 1).type_as(outputs)
|
||||
|
||||
cls_targets = []
|
||||
reg_targets = []
|
||||
obj_targets = []
|
||||
fg_masks = []
|
||||
|
||||
num_fg = 0.0
|
||||
for batch_idx in range(outputs.shape[0]):
|
||||
num_gt = len(labels[batch_idx])
|
||||
if num_gt == 0:
|
||||
cls_target = outputs.new_zeros((0, self.num_classes))
|
||||
reg_target = outputs.new_zeros((0, 4))
|
||||
obj_target = outputs.new_zeros((total_num_anchors, 1))
|
||||
fg_mask = outputs.new_zeros(total_num_anchors).bool()
|
||||
else:
|
||||
#-----------------------------------------------#
|
||||
# gt_bboxes_per_image [num_gt, num_classes]
|
||||
# gt_classes [num_gt]
|
||||
# bboxes_preds_per_image [n_anchors_all, 4]
|
||||
# cls_preds_per_image [n_anchors_all, num_classes]
|
||||
# obj_preds_per_image [n_anchors_all, 1]
|
||||
#-----------------------------------------------#
|
||||
gt_bboxes_per_image = labels[batch_idx][..., :4].type_as(outputs)
|
||||
gt_classes = labels[batch_idx][..., 4].type_as(outputs)
|
||||
bboxes_preds_per_image = bbox_preds[batch_idx]
|
||||
cls_preds_per_image = cls_preds[batch_idx]
|
||||
obj_preds_per_image = obj_preds[batch_idx]
|
||||
|
||||
gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments(
|
||||
num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image,
|
||||
expanded_strides, x_shifts, y_shifts,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
num_fg += num_fg_img
|
||||
cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1)
|
||||
obj_target = fg_mask.unsqueeze(-1)
|
||||
reg_target = gt_bboxes_per_image[matched_gt_inds]
|
||||
cls_targets.append(cls_target)
|
||||
reg_targets.append(reg_target)
|
||||
obj_targets.append(obj_target.type(cls_target.type()))
|
||||
fg_masks.append(fg_mask)
|
||||
|
||||
cls_targets = torch.cat(cls_targets, 0)
|
||||
reg_targets = torch.cat(reg_targets, 0)
|
||||
obj_targets = torch.cat(obj_targets, 0)
|
||||
fg_masks = torch.cat(fg_masks, 0)
|
||||
|
||||
num_fg = max(num_fg, 1)
|
||||
loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum()
|
||||
loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum()
|
||||
loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()
|
||||
reg_weight = 5.0
|
||||
loss = reg_weight * loss_iou + loss_obj + loss_cls
|
||||
|
||||
return loss / num_fg
|
||||
|
||||
@torch.no_grad()
|
||||
def get_assignments(self, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image, expanded_strides, x_shifts, y_shifts):
|
||||
#-------------------------------------------------------#
|
||||
# fg_mask [n_anchors_all]
|
||||
# is_in_boxes_and_center [num_gt, len(fg_mask)]
|
||||
#-------------------------------------------------------#
|
||||
fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# fg_mask [n_anchors_all]
|
||||
# bboxes_preds_per_image [fg_mask, 4]
|
||||
# cls_preds_ [fg_mask, num_classes]
|
||||
# obj_preds_ [fg_mask, 1]
|
||||
#-------------------------------------------------------#
|
||||
bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
|
||||
cls_preds_ = cls_preds_per_image[fg_mask]
|
||||
obj_preds_ = obj_preds_per_image[fg_mask]
|
||||
num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# pair_wise_ious [num_gt, fg_mask]
|
||||
#-------------------------------------------------------#
|
||||
pair_wise_ious = self.bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)
|
||||
pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# cls_preds_ [num_gt, fg_mask, num_classes]
|
||||
# gt_cls_per_image [num_gt, fg_mask, num_classes]
|
||||
#-------------------------------------------------------#
|
||||
if self.fp16:
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
|
||||
gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
|
||||
pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
|
||||
else:
|
||||
cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
|
||||
gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
|
||||
pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
|
||||
del cls_preds_
|
||||
|
||||
cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center).float()
|
||||
|
||||
num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
|
||||
del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
|
||||
return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg
|
||||
|
||||
def bboxes_iou(self, bboxes_a, bboxes_b, xyxy=True):
|
||||
if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
|
||||
raise IndexError
|
||||
|
||||
if xyxy:
|
||||
tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
|
||||
br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
|
||||
area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
|
||||
area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
|
||||
else:
|
||||
tl = torch.max(
|
||||
(bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
|
||||
(bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
|
||||
)
|
||||
br = torch.min(
|
||||
(bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
|
||||
(bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
|
||||
)
|
||||
|
||||
area_a = torch.prod(bboxes_a[:, 2:], 1)
|
||||
area_b = torch.prod(bboxes_b[:, 2:], 1)
|
||||
en = (tl < br).type(tl.type()).prod(dim=2)
|
||||
area_i = torch.prod(br - tl, 2) * en
|
||||
return area_i / (area_a[:, None] + area_b - area_i)
|
||||
|
||||
def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius = 2.5):
|
||||
#-------------------------------------------------------#
|
||||
# expanded_strides_per_image [n_anchors_all]
|
||||
# x_centers_per_image [num_gt, n_anchors_all]
|
||||
# x_centers_per_image [num_gt, n_anchors_all]
|
||||
#-------------------------------------------------------#
|
||||
expanded_strides_per_image = expanded_strides[0]
|
||||
x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
|
||||
y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# gt_bboxes_per_image_x [num_gt, n_anchors_all]
|
||||
#-------------------------------------------------------#
|
||||
gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
|
||||
gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
|
||||
gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
|
||||
gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# bbox_deltas [num_gt, n_anchors_all, 4]
|
||||
#-------------------------------------------------------#
|
||||
b_l = x_centers_per_image - gt_bboxes_per_image_l
|
||||
b_r = gt_bboxes_per_image_r - x_centers_per_image
|
||||
b_t = y_centers_per_image - gt_bboxes_per_image_t
|
||||
b_b = gt_bboxes_per_image_b - y_centers_per_image
|
||||
bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# is_in_boxes [num_gt, n_anchors_all]
|
||||
# is_in_boxes_all [n_anchors_all]
|
||||
#-------------------------------------------------------#
|
||||
is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
|
||||
is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
|
||||
|
||||
gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
|
||||
gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
|
||||
gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
|
||||
gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# center_deltas [num_gt, n_anchors_all, 4]
|
||||
#-------------------------------------------------------#
|
||||
c_l = x_centers_per_image - gt_bboxes_per_image_l
|
||||
c_r = gt_bboxes_per_image_r - x_centers_per_image
|
||||
c_t = y_centers_per_image - gt_bboxes_per_image_t
|
||||
c_b = gt_bboxes_per_image_b - y_centers_per_image
|
||||
center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# is_in_centers [num_gt, n_anchors_all]
|
||||
# is_in_centers_all [n_anchors_all]
|
||||
#-------------------------------------------------------#
|
||||
is_in_centers = center_deltas.min(dim=-1).values > 0.0
|
||||
is_in_centers_all = is_in_centers.sum(dim=0) > 0
|
||||
|
||||
#-------------------------------------------------------#
|
||||
# is_in_boxes_anchor [n_anchors_all]
|
||||
# is_in_boxes_and_center [num_gt, is_in_boxes_anchor]
|
||||
#-------------------------------------------------------#
|
||||
is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
|
||||
is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
|
||||
return is_in_boxes_anchor, is_in_boxes_and_center
|
||||
|
||||
def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
|
||||
#-------------------------------------------------------#
|
||||
# cost [num_gt, fg_mask]
|
||||
# pair_wise_ious [num_gt, fg_mask]
|
||||
# gt_classes [num_gt]
|
||||
# fg_mask [n_anchors_all]
|
||||
# matching_matrix [num_gt, fg_mask]
|
||||
#-------------------------------------------------------#
|
||||
matching_matrix = torch.zeros_like(cost)
|
||||
|
||||
#------------------------------------------------------------#
|
||||
# 选取iou最大的n_candidate_k个点
|
||||
# 然后求和,判断应该有多少点用于该框预测
|
||||
# topk_ious [num_gt, n_candidate_k]
|
||||
# dynamic_ks [num_gt]
|
||||
# matching_matrix [num_gt, fg_mask]
|
||||
#------------------------------------------------------------#
|
||||
n_candidate_k = min(10, pair_wise_ious.size(1))
|
||||
topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
|
||||
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
|
||||
|
||||
for gt_idx in range(num_gt):
|
||||
#------------------------------------------------------------#
|
||||
# 给每个真实框选取最小的动态k个点
|
||||
#------------------------------------------------------------#
|
||||
_, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
|
||||
matching_matrix[gt_idx][pos_idx] = 1.0
|
||||
del topk_ious, dynamic_ks, pos_idx
|
||||
|
||||
#------------------------------------------------------------#
|
||||
# anchor_matching_gt [fg_mask]
|
||||
#------------------------------------------------------------#
|
||||
anchor_matching_gt = matching_matrix.sum(0)
|
||||
if (anchor_matching_gt > 1).sum() > 0:
|
||||
#------------------------------------------------------------#
|
||||
# 当某一个特征点指向多个真实框的时候
|
||||
# 选取cost最小的真实框。
|
||||
#------------------------------------------------------------#
|
||||
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
|
||||
matching_matrix[:, anchor_matching_gt > 1] *= 0.0
|
||||
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
|
||||
#------------------------------------------------------------#
|
||||
# fg_mask_inboxes [fg_mask]
|
||||
# num_fg为正样本的特征点个数
|
||||
#------------------------------------------------------------#
|
||||
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
|
||||
num_fg = fg_mask_inboxes.sum().item()
|
||||
|
||||
#------------------------------------------------------------#
|
||||
# 对fg_mask进行更新
|
||||
#------------------------------------------------------------#
|
||||
fg_mask[fg_mask.clone()] = fg_mask_inboxes
|
||||
|
||||
#------------------------------------------------------------#
|
||||
# 获得特征点对应的物品种类
|
||||
#------------------------------------------------------------#
|
||||
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
|
||||
gt_matched_classes = gt_classes[matched_gt_inds]
|
||||
|
||||
pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
|
||||
return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
|
||||
|
||||
def is_parallel(model):
|
||||
# Returns True if model is of type DP or DDP
|
||||
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||
|
||||
def de_parallel(model):
|
||||
# De-parallelize a model: returns single-GPU model if model is of type DP or DDP
|
||||
return model.module if is_parallel(model) else model
|
||||
|
||||
def copy_attr(a, b, include=(), exclude=()):
|
||||
# Copy attributes from b to a, options to only include [...] and to exclude [...]
|
||||
for k, v in b.__dict__.items():
|
||||
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
|
||||
continue
|
||||
else:
|
||||
setattr(a, k, v)
|
||||
|
||||
class ModelEMA:
|
||||
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
|
||||
Keeps a moving average of everything in the model state_dict (parameters and buffers)
|
||||
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
||||
"""
|
||||
|
||||
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
||||
# Create EMA
|
||||
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
||||
# if next(model.parameters()).device.type != 'cpu':
|
||||
# self.ema.half() # FP16 EMA
|
||||
self.updates = updates # number of EMA updates
|
||||
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
||||
for p in self.ema.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
def update(self, model):
|
||||
# Update EMA parameters
|
||||
with torch.no_grad():
|
||||
self.updates += 1
|
||||
d = self.decay(self.updates)
|
||||
|
||||
msd = de_parallel(model).state_dict() # model state_dict
|
||||
for k, v in self.ema.state_dict().items():
|
||||
if v.dtype.is_floating_point:
|
||||
v *= d
|
||||
v += (1 - d) * msd[k].detach()
|
||||
|
||||
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
||||
# Update EMA attributes
|
||||
copy_attr(self.ema, model, include, exclude)
|
||||
|
||||
def weights_init(net, init_type='normal', init_gain = 0.02):
|
||||
def init_func(m):
|
||||
classname = m.__class__.__name__
|
||||
if hasattr(m, 'weight') and classname.find('Conv') != -1:
|
||||
if init_type == 'normal':
|
||||
torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
|
||||
elif init_type == 'xavier':
|
||||
torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
|
||||
elif init_type == 'kaiming':
|
||||
torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
||||
elif init_type == 'orthogonal':
|
||||
torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
|
||||
else:
|
||||
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
||||
elif classname.find('BatchNorm2d') != -1:
|
||||
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||
torch.nn.init.constant_(m.bias.data, 0.0)
|
||||
print('initialize network with %s type' % init_type)
|
||||
net.apply(init_func)
|
||||
|
||||
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
|
||||
def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
|
||||
if iters <= warmup_total_iters:
|
||||
# lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
|
||||
lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
|
||||
elif iters >= total_iters - no_aug_iter:
|
||||
lr = min_lr
|
||||
else:
|
||||
lr = min_lr + 0.5 * (lr - min_lr) * (
|
||||
1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
|
||||
)
|
||||
return lr
|
||||
|
||||
def step_lr(lr, decay_rate, step_size, iters):
|
||||
if step_size < 1:
|
||||
raise ValueError("step_size must above 1.")
|
||||
n = iters // step_size
|
||||
out_lr = lr * decay_rate ** n
|
||||
return out_lr
|
||||
|
||||
if lr_decay_type == "cos":
|
||||
warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
|
||||
warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
|
||||
no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
|
||||
func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
|
||||
else:
|
||||
decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
|
||||
step_size = total_iters / step_num
|
||||
func = partial(step_lr, lr, decay_rate, step_size)
|
||||
|
||||
return func
|
||||
|
||||
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
|
||||
lr = lr_scheduler_func(epoch)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
@@ -0,0 +1,180 @@
|
||||
# -----------------------------------------------------------------------#
|
||||
# predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能
|
||||
# 整合到了一个py文件中,通过指定mode进行模式的修改。
|
||||
# -----------------------------------------------------------------------#
|
||||
import os
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from yolo import YOLO
|
||||
|
||||
|
||||
def PREDICT(dataPath, resultPath, fileName):
|
||||
yolo = YOLO()
|
||||
# ----------------------------------------------------------------------------------------------------------#
|
||||
# mode用于指定测试的模式:
|
||||
# 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释
|
||||
# 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。
|
||||
# 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。
|
||||
# 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。
|
||||
# 'heatmap' 表示进行预测结果的热力图可视化,详情查看下方注释。
|
||||
# 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。
|
||||
# ----------------------------------------------------------------------------------------------------------#
|
||||
mode = "predict"
|
||||
# -------------------------------------------------------------------------#
|
||||
# crop 指定了是否在单张图片预测后对目标进行截取
|
||||
# count 指定了是否进行目标的计数
|
||||
# crop、count仅在mode='predict'时有效
|
||||
# -------------------------------------------------------------------------#
|
||||
crop = False
|
||||
count = False
|
||||
# ----------------------------------------------------------------------------------------------------------#
|
||||
# video_path 用于指定视频的路径,当video_path=0时表示检测摄像头
|
||||
# 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
|
||||
# video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存
|
||||
# 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。
|
||||
# video_fps 用于保存的视频的fps
|
||||
#
|
||||
# video_path、video_save_path和video_fps仅在mode='video'时有效
|
||||
# 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。
|
||||
# ----------------------------------------------------------------------------------------------------------#
|
||||
video_path = 0
|
||||
video_save_path = ""
|
||||
video_fps = 25.0
|
||||
# ----------------------------------------------------------------------------------------------------------#
|
||||
# test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。
|
||||
# fps_image_path 用于指定测试的fps图片
|
||||
#
|
||||
# test_interval和fps_image_path仅在mode='fps'有效
|
||||
# ----------------------------------------------------------------------------------------------------------#
|
||||
test_interval = 100
|
||||
fps_image_path = "img/street.jpg"
|
||||
# -------------------------------------------------------------------------#
|
||||
# dir_origin_path 指定了用于检测的图片的文件夹路径
|
||||
# dir_save_path 指定了检测完图片的保存路径
|
||||
#
|
||||
# dir_origin_path和dir_save_path仅在mode='dir_predict'时有效
|
||||
# ----------------- --------------------------------------------------------#
|
||||
dir_origin_path = "yolox-pytorch/img/"
|
||||
dir_save_path = "img_out/"
|
||||
# -------------------------------------------------------------------------#
|
||||
# heatmap_save_path 热力图的保存路径,默认保存在model_data下
|
||||
#
|
||||
# heatmap_save_path仅在mode='heatmap'有效
|
||||
# -------------------------------------------------------------------------#
|
||||
heatmap_save_path = "model_data/heatmap_vision.png"
|
||||
# -------------------------------------------------------------------------#
|
||||
# simplify 使用Simplify onnx
|
||||
# onnx_save_path 指定了onnx的保存路径
|
||||
# -------------------------------------------------------------------------#
|
||||
simplify = True
|
||||
onnx_save_path = "model_data/models.onnx"
|
||||
|
||||
if mode == "predict":
|
||||
'''
|
||||
1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。
|
||||
2、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。
|
||||
3、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值
|
||||
在原图上利用矩阵的方式进行截取。
|
||||
4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断,
|
||||
比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。
|
||||
'''
|
||||
|
||||
imgPath = os.path.join(dataPath, fileName)
|
||||
resultFilePath = os.path.join(resultPath, fileName)
|
||||
image = Image.open(imgPath)
|
||||
|
||||
r_image, result = yolo.detect_image(image, crop=crop, count=count)
|
||||
r_image.show()
|
||||
|
||||
# 保存图片
|
||||
r_image.save(resultFilePath)
|
||||
|
||||
return result # 返回预测得到的分类名以及得分
|
||||
|
||||
elif mode == "video":
|
||||
capture = cv2.VideoCapture(video_path)
|
||||
if video_save_path != "":
|
||||
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
||||
size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
||||
out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
|
||||
|
||||
ref, frame = capture.read()
|
||||
if not ref:
|
||||
raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")
|
||||
|
||||
fps = 0.0
|
||||
while (True):
|
||||
t1 = time.time()
|
||||
# 读取某一帧
|
||||
ref, frame = capture.read()
|
||||
if not ref:
|
||||
break
|
||||
# 格式转变,BGRtoRGB
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
# 转变成Image
|
||||
frame = Image.fromarray(np.uint8(frame))
|
||||
# 进行检测
|
||||
frame = np.array(yolo.detect_image(frame))
|
||||
# RGBtoBGR满足opencv显示格式
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
||||
|
||||
fps = (fps + (1. / (time.time() - t1))) / 2
|
||||
print("fps= %.2f" % (fps))
|
||||
frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
||||
|
||||
cv2.imshow("video", frame)
|
||||
c = cv2.waitKey(1) & 0xff
|
||||
if video_save_path != "":
|
||||
out.write(frame)
|
||||
|
||||
if c == 27:
|
||||
capture.release()
|
||||
break
|
||||
|
||||
print("Video Detection Done!")
|
||||
capture.release()
|
||||
if video_save_path != "":
|
||||
print("Save processed video to the path :" + video_save_path)
|
||||
out.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
elif mode == "fps":
|
||||
img = Image.open(fps_image_path)
|
||||
tact_time = yolo.get_FPS(img, test_interval)
|
||||
print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1')
|
||||
|
||||
elif mode == "dir_predict":
|
||||
|
||||
img_names = os.listdir(dir_origin_path)
|
||||
for img_name in tqdm(img_names):
|
||||
if img_name.lower().endswith(
|
||||
('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
|
||||
image_path = os.path.join(dir_origin_path, img_name)
|
||||
image = Image.open(image_path)
|
||||
r_image = yolo.detect_image(image)
|
||||
if not os.path.exists(dir_save_path):
|
||||
os.makedirs(dir_save_path)
|
||||
r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
|
||||
|
||||
elif mode == "heatmap":
|
||||
while True:
|
||||
img = input('Input image filename:')
|
||||
try:
|
||||
image = Image.open(img)
|
||||
except:
|
||||
print('Open Error! Try again!')
|
||||
continue
|
||||
else:
|
||||
yolo.detect_heatmap(image, heatmap_save_path)
|
||||
|
||||
elif mode == "export_onnx":
|
||||
yolo.convert_to_onnx(simplify, onnx_save_path)
|
||||
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.")
|
||||
@@ -0,0 +1,9 @@
|
||||
'''
|
||||
Author: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
|
||||
Date: 2023-04-01 21:34:17
|
||||
LastEditors: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
|
||||
LastEditTime: 2023-04-01 21:41:27
|
||||
FilePath: /luyuetong-data/yolox-pytorch/quanju.py
|
||||
Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
|
||||
'''
|
||||
mingzi = ''
|
||||
@@ -0,0 +1,9 @@
|
||||
scipy==1.2.1
|
||||
numpy==1.17.0
|
||||
matplotlib==3.1.2
|
||||
opencv_python==4.1.2.30
|
||||
torch==1.2.0
|
||||
torchvision==0.4.0
|
||||
tqdm==4.60.0
|
||||
Pillow==8.2.0
|
||||
h5py==2.10.0
|
||||
|
After Width: | Height: | Size: 53 KiB |
|
After Width: | Height: | Size: 46 KiB |
|
After Width: | Height: | Size: 69 KiB |
|
After Width: | Height: | Size: 180 KiB |
|
After Width: | Height: | Size: 152 KiB |
|
After Width: | Height: | Size: 152 KiB |
@@ -0,0 +1,31 @@
|
||||
#--------------------------------------------#
|
||||
# 该部分代码用于看网络结构
|
||||
#--------------------------------------------#
|
||||
import torch
|
||||
from thop import clever_format, profile
|
||||
from torchsummary import summary
|
||||
|
||||
from nets.yolo import YoloBody
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_shape = [640, 640]
|
||||
num_classes = 80
|
||||
phi = 'l'
|
||||
|
||||
# 需要使用device来指定网络在GPU还是CPU运行
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
m = YoloBody(num_classes, phi).to(device)
|
||||
summary(m, (3, input_shape[0], input_shape[1]))
|
||||
|
||||
dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
|
||||
flops, params = profile(m.to(device), (dummy_input, ), verbose=False)
|
||||
#--------------------------------------------------------#
|
||||
# flops * 2是因为profile没有将卷积作为两个operations
|
||||
# 有些论文将卷积算乘法、加法两个operations。此时乘2
|
||||
# 有些论文只考虑乘法的运算次数,忽略加法。此时不乘2
|
||||
# 本代码选择乘2,参考YOLOX。
|
||||
#--------------------------------------------------------#
|
||||
flops = flops * 2
|
||||
flops, params = clever_format([flops, params], "%.3f")
|
||||
print('Total GFLOPS: %s' % (flops))
|
||||
print('Total params: %s' % (params))
|
||||
@@ -0,0 +1,15 @@
|
||||
from flask import Flask, jsonify, request, make_response, render_template
|
||||
from PIL import Image
|
||||
from yolo import YOLO
|
||||
from predict import PREDICT
|
||||
import calendar,time,os
|
||||
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
# 用np.load加载.npy文件
|
||||
data = np.load('E:\Temp/train_64/train_64.npy')
|
||||
|
||||
# 查看数组的形状
|
||||
print(data.shape)
|
||||
@@ -0,0 +1,543 @@
|
||||
# -------------------------------------#
|
||||
# 对数据集进行训练
|
||||
# -------------------------------------#
|
||||
import datetime
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from nets.yolo import YoloBody
|
||||
from nets.yolo_training import (ModelEMA, YOLOLoss, get_lr_scheduler,
|
||||
set_optimizer_lr, weights_init)
|
||||
from utils.callbacks import LossHistory, EvalCallback
|
||||
from utils.dataloader import YoloDataset, yolo_dataset_collate
|
||||
from utils.utils import get_classes, show_config
|
||||
from utils.utils_fit import fit_one_epoch
|
||||
|
||||
'''
|
||||
训练自己的目标检测模型一定需要注意以下几点:
|
||||
1、训练前仔细检查自己的格式是否满足要求,该库要求数据集格式为VOC格式,需要准备好的内容有输入图片和标签
|
||||
输入图片为.jpg图片,无需固定大小,传入训练前会自动进行resize。
|
||||
灰度图会自动转成RGB图片进行训练,无需自己修改。
|
||||
输入图片如果后缀非jpg,需要自己批量转成jpg后再开始训练。
|
||||
|
||||
标签为.xml格式,文件中会有需要检测的目标信息,标签文件和输入图片文件相对应。
|
||||
|
||||
2、损失值的大小用于判断是否收敛,比较重要的是有收敛的趋势,即验证集损失不断下降,如果验证集损失基本上不改变的话,模型基本上就收敛了。
|
||||
损失值的具体大小并没有什么意义,大和小只在于损失的计算方式,并不是接近于0才好。如果想要让损失好看点,可以直接到对应的损失函数里面除上10000。
|
||||
训练过程中的损失值会保存在logs文件夹下的loss_%Y_%m_%d_%H_%M_%S文件夹中
|
||||
|
||||
3、训练好的权值文件保存在logs文件夹中,每个训练世代(Epoch)包含若干训练步长(Step),每个训练步长(Step)进行一次梯度下降。
|
||||
如果只是训练了几个Step是不会保存的,Epoch和Step的概念要捋清楚一下。
|
||||
'''
|
||||
if __name__ == "__main__":
|
||||
# ---------------------------------#
|
||||
# Cuda 是否使用Cuda
|
||||
# 没有GPU可以设置成False
|
||||
# ---------------------------------#
|
||||
Cuda = True
|
||||
# ---------------------------------------------------------------------#
|
||||
# distributed 用于指定是否使用单机多卡分布式运行
|
||||
# 终端指令仅支持Ubuntu。CUDA_VISIBLE_DEVICES用于在Ubuntu下指定显卡。
|
||||
# Windows系统下默认使用DP模式调用所有显卡,不支持DDP。
|
||||
# DP模式:
|
||||
# 设置 distributed = False
|
||||
# 在终端中输入 CUDA_VISIBLE_DEVICES=0,1 python train.py
|
||||
# DDP模式:
|
||||
# 设置 distributed = True
|
||||
# 在终端中输入 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
|
||||
# ---------------------------------------------------------------------#
|
||||
distributed = False
|
||||
# ---------------------------------------------------------------------#
|
||||
# sync_bn 是否使用sync_bn,DDP模式多卡可用
|
||||
# ---------------------------------------------------------------------#
|
||||
sync_bn = False
|
||||
# ---------------------------------------------------------------------#
|
||||
# fp16 是否使用混合精度训练
|
||||
# 可减少约一半的显存、需要pytorch1.7.1以上
|
||||
# ---------------------------------------------------------------------#
|
||||
fp16 = False
|
||||
# ---------------------------------------------------------------------#
|
||||
# classes_path 指向model_data下的txt,与自己训练的数据集相关
|
||||
# 训练前一定要修改classes_path,使其对应自己的数据集
|
||||
# ---------------------------------------------------------------------#
|
||||
classes_path = 'model_data/voc_classes.txt'
|
||||
# ----------------------------------------------------------------------------------------------------------------------------#
|
||||
# 权值文件的下载请看README,可以通过网盘下载。模型的 预训练权重 对不同数据集是通用的,因为特征是通用的。
|
||||
# 模型的 预训练权重 比较重要的部分是 主干特征提取网络的权值部分,用于进行特征提取。
|
||||
# 预训练权重对于99%的情况都必须要用,不用的话主干部分的权值太过随机,特征提取效果不明显,网络训练的结果也不会好
|
||||
#
|
||||
# 如果训练过程中存在中断训练的操作,可以将model_path设置成logs文件夹下的权值文件,将已经训练了一部分的权值再次载入。
|
||||
# 同时修改下方的 冻结阶段 或者 解冻阶段 的参数,来保证模型epoch的连续性。
|
||||
#
|
||||
# 当model_path = ''的时候不加载整个模型的权值。
|
||||
#
|
||||
# 此处使用的是整个模型的权重,因此是在train.py进行加载的。
|
||||
# 如果想要让模型从0开始训练,则设置model_path = '',下面的Freeze_Train = Fasle,此时从0开始训练,且没有冻结主干的过程。
|
||||
#
|
||||
# 一般来讲,网络从0开始的训练效果会很差,因为权值太过随机,特征提取效果不明显,因此非常、非常、非常不建议大家从0开始训练!
|
||||
# 从0开始训练有两个方案:
|
||||
# 1、得益于Mosaic数据增强方法强大的数据增强能力,将UnFreeze_Epoch设置的较大(300及以上)、batch较大(16及以上)、数据较多(万以上)的情况下,
|
||||
# 可以设置mosaic=True,直接随机初始化参数开始训练,但得到的效果仍然不如有预训练的情况。(像COCO这样的大数据集可以这样做)
|
||||
# 2、了解imagenet数据集,首先训练分类模型,获得网络的主干部分权值,分类模型的 主干部分 和该模型通用,基于此进行训练。
|
||||
# ----------------------------------------------------------------------------------------------------------------------------#
|
||||
model_path = 'model_data/yolox_s.pth' # ..
|
||||
# ------------------------------------------------------#
|
||||
# input_shape 输入的shape大小,一定要是32的倍数
|
||||
# ------------------------------------------------------#
|
||||
input_shape = [640, 640]
|
||||
# ------------------------------------------------------#
|
||||
# 所使用的YoloX的版本。nano、tiny、s、m、l、x
|
||||
# ------------------------------------------------------#
|
||||
phi = 's' # ..
|
||||
# ------------------------------------------------------------------#
|
||||
# mosaic 马赛克数据增强。
|
||||
# mosaic_prob 每个step有多少概率使用mosaic数据增强,默认50%。
|
||||
#
|
||||
# mixup 是否使用mixup数据增强,仅在mosaic=True时有效。
|
||||
# 只会对mosaic增强后的图片进行mixup的处理。
|
||||
# mixup_prob 有多少概率在mosaic后使用mixup数据增强,默认50%。
|
||||
# 总的mixup概率为mosaic_prob * mixup_prob。
|
||||
#
|
||||
# special_aug_ratio 参考YoloX,由于Mosaic生成的训练图片,远远脱离自然图片的真实分布。
|
||||
# 当mosaic=True时,本代码会在special_aug_ratio范围内开启mosaic。
|
||||
# 默认为前70%个epoch,100个世代会开启70个世代。
|
||||
#
|
||||
# 余弦退火算法的参数放到下面的lr_decay_type中设置
|
||||
# ------------------------------------------------------------------#
|
||||
mosaic = True
|
||||
mosaic_prob = 0.5
|
||||
mixup = True
|
||||
mixup_prob = 0.5
|
||||
special_aug_ratio = 0.7
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------------#
|
||||
# 训练分为两个阶段,分别是冻结阶段和解冻阶段。设置冻结阶段是为了满足机器性能不足的同学的训练需求。
|
||||
# 冻结训练需要的显存较小,显卡非常差的情况下,可设置Freeze_Epoch等于UnFreeze_Epoch,Freeze_Train = True,此时仅仅进行冻结训练。
|
||||
#
|
||||
# 在此提供若干参数设置建议,各位训练者根据自己的需求进行灵活调整:
|
||||
# (一)从整个模型的预训练权重开始训练:
|
||||
# Adam:
|
||||
# Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 100,Freeze_Train = True,optimizer_type = 'adam',Init_lr = 1e-3,weight_decay = 0。(冻结)
|
||||
# Init_Epoch = 0,UnFreeze_Epoch = 100,Freeze_Train = False,optimizer_type = 'adam',Init_lr = 1e-3,weight_decay = 0。(不冻结)
|
||||
# SGD:
|
||||
# Init_Epoch = 0,Freeze_Epoch = 50,UnFreeze_Epoch = 300,Freeze_Train = True,optimizer_type = 'sgd',Init_lr = 1e-2,weight_decay = 5e-4。(冻结)
|
||||
# Init_Epoch = 0,UnFreeze_Epoch = 300,Freeze_Train = False,optimizer_type = 'sgd',Init_lr = 1e-2,weight_decay = 5e-4。(不冻结)
|
||||
# 其中:UnFreeze_Epoch可以在100-300之间调整。
|
||||
# (二)从0开始训练:
|
||||
# Init_Epoch = 0,UnFreeze_Epoch >= 300,Unfreeze_batch_size >= 16,Freeze_Train = False(不冻结训练)
|
||||
# 其中:UnFreeze_Epoch尽量不小于300。optimizer_type = 'sgd',Init_lr = 1e-2,mosaic = True。
|
||||
# (三)batch_size的设置:
|
||||
# 在显卡能够接受的范围内,以大为好。显存不足与数据集大小无关,提示显存不足(OOM或者CUDA out of memory)请调小batch_size。
|
||||
# 受到BatchNorm层影响,batch_size最小为2,不能为1。
|
||||
# 正常情况下Freeze_batch_size建议为Unfreeze_batch_size的1-2倍。不建议设置的差距过大,因为关系到学习率的自动调整。
|
||||
# ----------------------------------------------------------------------------------------------------------------------------#
|
||||
# ------------------------------------------------------------------#
|
||||
# 冻结阶段训练参数
|
||||
# 此时模型的主干被冻结了,特征提取网络不发生改变
|
||||
# 占用的显存较小,仅对网络进行微调
|
||||
# Init_Epoch 模型当前开始的训练世代,其值可以大于Freeze_Epoch,如设置:
|
||||
# Init_Epoch = 60、Freeze_Epoch = 50、UnFreeze_Epoch = 100
|
||||
# 会跳过冻结阶段,直接从60代开始,并调整对应的学习率。
|
||||
# (断点续练时使用)
|
||||
# Freeze_Epoch 模型冻结训练的Freeze_Epoch
|
||||
# (当Freeze_Train=False时失效)
|
||||
# Freeze_batch_size 模型冻结训练的batch_size
|
||||
# (当Freeze_Train=False时失效)
|
||||
# ------------------------------------------------------------------#
|
||||
Init_Epoch = 0 # ..
|
||||
Freeze_Epoch = 50 # ..
|
||||
Freeze_batch_size = 16 # ..
|
||||
# ------------------------------------------------------------------#
|
||||
# 解冻阶段训练参数
|
||||
# 此时模型的主干不被冻结了,特征提取网络会发生改变
|
||||
# 占用的显存较大,网络所有的参数都会发生改变
|
||||
# UnFreeze_Epoch 模型总共训练的epoch
|
||||
# SGD需要更长的时间收敛,因此设置较大的UnFreeze_Epoch
|
||||
# Adam可以使用相对较小的UnFreeze_Epoch
|
||||
# Unfreeze_batch_size 模型在解冻后的batch_size
|
||||
# ------------------------------------------------------------------#
|
||||
UnFreeze_Epoch = 300 # ..
|
||||
Unfreeze_batch_size = 8 # ..
|
||||
# ------------------------------------------------------------------#
|
||||
# Freeze_Train 是否进行冻结训练
|
||||
# 默认先冻结主干训练后解冻训练。
|
||||
# ------------------------------------------------------------------#
|
||||
Freeze_Train = True # ..
|
||||
|
||||
# ------------------------------------------------------------------#
|
||||
# 其它训练参数:学习率、优化器、学习率下降有关
|
||||
# ------------------------------------------------------------------#
|
||||
# ------------------------------------------------------------------#
|
||||
# Init_lr 模型的最大学习率
|
||||
# Min_lr 模型的最小学习率,默认为最大学习率的0.01
|
||||
# ------------------------------------------------------------------#
|
||||
Init_lr = 1e-2
|
||||
Min_lr = Init_lr * 0.01
|
||||
# ------------------------------------------------------------------#
|
||||
# optimizer_type 使用到的优化器种类,可选的有adam、sgd
|
||||
# 当使用Adam优化器时建议设置 Init_lr=1e-3
|
||||
# 当使用SGD优化器时建议设置 Init_lr=1e-2
|
||||
# momentum 优化器内部使用到的momentum参数
|
||||
# weight_decay 权值衰减,可防止过拟合
|
||||
# adam会导致weight_decay错误,使用adam时建议设置为0。
|
||||
# ------------------------------------------------------------------#
|
||||
optimizer_type = "sgd" #
|
||||
momentum = 0.937
|
||||
weight_decay = 5e-4
|
||||
# ------------------------------------------------------------------#
|
||||
# lr_decay_type 使用到的学习率下降方式,可选的有step、cos
|
||||
# ------------------------------------------------------------------#
|
||||
lr_decay_type = "cos"
|
||||
# ------------------------------------------------------------------#
|
||||
# save_period 多少个epoch保存一次权值
|
||||
# ------------------------------------------------------------------#
|
||||
save_period = 10
|
||||
# ------------------------------------------------------------------#
|
||||
# save_dir 权值与日志文件保存的文件夹
|
||||
# ------------------------------------------------------------------#
|
||||
save_dir = 'logs'
|
||||
# ------------------------------------------------------------------#
|
||||
# eval_flag 是否在训练时进行评估,评估对象为验证集
|
||||
# 安装pycocotools库后,评估体验更佳。
|
||||
# eval_period 代表多少个epoch评估一次,不建议频繁的评估
|
||||
# 评估需要消耗较多的时间,频繁评估会导致训练非常慢
|
||||
# 此处获得的mAP会与get_map.py获得的会有所不同,原因有二:
|
||||
# (一)此处获得的mAP为验证集的mAP。
|
||||
# (二)此处设置评估参数较为保守,目的是加快评估速度。
|
||||
# ------------------------------------------------------------------#
|
||||
eval_flag = True
|
||||
eval_period = 10
|
||||
# ------------------------------------------------------------------#
|
||||
# num_workers 用于设置是否使用多线程读取数据
|
||||
# 开启后会加快数据读取速度,但是会占用更多内存
|
||||
# 内存较小的电脑可以设置为2或者0
|
||||
# ------------------------------------------------------------------#
|
||||
num_workers = 4
|
||||
|
||||
# ----------------------------------------------------#
|
||||
# 获得图片路径和标签
|
||||
# ----------------------------------------------------#
|
||||
train_annotation_path = '2007_train.txt'
|
||||
val_annotation_path = '2007_val.txt'
|
||||
|
||||
# ------------------------------------------------------#
|
||||
# 设置用到的显卡
|
||||
# ------------------------------------------------------#
|
||||
ngpus_per_node = torch.cuda.device_count()
|
||||
if distributed:
|
||||
dist.init_process_group(backend="nccl")
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
rank = int(os.environ["RANK"])
|
||||
device = torch.device("cuda", local_rank)
|
||||
if local_rank == 0:
|
||||
print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...")
|
||||
print("Gpu Device Count : ", ngpus_per_node)
|
||||
else:
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
local_rank = 0
|
||||
rank = 0
|
||||
|
||||
# ----------------------------------------------------#
|
||||
# 获取classes和anchor
|
||||
# ----------------------------------------------------#
|
||||
class_names, num_classes = get_classes(classes_path)
|
||||
|
||||
# ------------------------------------------------------#
|
||||
# 创建yolo模型
|
||||
# ------------------------------------------------------#
|
||||
model = YoloBody(num_classes, phi)
|
||||
weights_init(model)
|
||||
if model_path != '':
|
||||
# ------------------------------------------------------#
|
||||
# 权值文件请看README,百度网盘下载
|
||||
# ------------------------------------------------------#
|
||||
if local_rank == 0:
|
||||
print('Load weights {}.'.format(model_path))
|
||||
|
||||
# ------------------------------------------------------#
|
||||
# 根据预训练权重的Key和模型的Key进行加载
|
||||
# ------------------------------------------------------#
|
||||
model_dict = model.state_dict()
|
||||
pretrained_dict = torch.load(model_path, map_location=device)
|
||||
load_key, no_load_key, temp_dict = [], [], {}
|
||||
for k, v in pretrained_dict.items():
|
||||
if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
|
||||
temp_dict[k] = v
|
||||
load_key.append(k)
|
||||
else:
|
||||
no_load_key.append(k)
|
||||
model_dict.update(temp_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
# ------------------------------------------------------#
|
||||
# 显示没有匹配上的Key
|
||||
# ------------------------------------------------------#
|
||||
if local_rank == 0:
|
||||
print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key))
|
||||
print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key))
|
||||
print("\n\033[1;33;44m温馨提示,head部分没有载入是正常现象,Backbone部分没有载入是错误的。\033[0m")
|
||||
|
||||
# ----------------------#
|
||||
# 获得损失函数
|
||||
# ----------------------#
|
||||
yolo_loss = YOLOLoss(num_classes, fp16)
|
||||
# ----------------------#
|
||||
# 记录Loss
|
||||
# ----------------------#
|
||||
if local_rank == 0:
|
||||
time_str = datetime.datetime.strftime(datetime.datetime.now(), '%Y_%m_%d_%H_%M_%S')
|
||||
log_dir = os.path.join(save_dir, "loss_" + str(time_str))
|
||||
loss_history = LossHistory(log_dir, model, input_shape=input_shape)
|
||||
else:
|
||||
loss_history = None
|
||||
|
||||
# ------------------------------------------------------------------#
|
||||
# torch 1.2不支持amp,建议使用torch 1.7.1及以上正确使用fp16
|
||||
# 因此torch1.2这里显示"could not be resolve"
|
||||
# ------------------------------------------------------------------#
|
||||
if fp16:
|
||||
from torch.cuda.amp import GradScaler as GradScaler
|
||||
|
||||
scaler = GradScaler()
|
||||
else:
|
||||
scaler = None
|
||||
|
||||
model_train = model.train()
|
||||
# ----------------------------#
|
||||
# 多卡同步Bn
|
||||
# ----------------------------#
|
||||
if sync_bn and ngpus_per_node > 1 and distributed:
|
||||
model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train)
|
||||
elif sync_bn:
|
||||
print("Sync_bn is not support in one gpu or not distributed.")
|
||||
|
||||
if Cuda:
|
||||
if distributed:
|
||||
# ----------------------------#
|
||||
# 多卡平行运行
|
||||
# ----------------------------#
|
||||
model_train = model_train.cuda(local_rank)
|
||||
model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank],
|
||||
find_unused_parameters=True)
|
||||
else:
|
||||
model_train = torch.nn.DataParallel(model)
|
||||
cudnn.benchmark = True
|
||||
model_train = model_train.cuda()
|
||||
|
||||
# ----------------------------#
|
||||
# 权值平滑
|
||||
# ----------------------------#
|
||||
ema = ModelEMA(model_train)
|
||||
|
||||
# ---------------------------#
|
||||
# 读取数据集对应的txt
|
||||
# ---------------------------#
|
||||
with open(train_annotation_path, encoding='utf-8') as f:
|
||||
train_lines = f.readlines()
|
||||
with open(val_annotation_path, encoding='utf-8') as f:
|
||||
val_lines = f.readlines()
|
||||
num_train = len(train_lines)
|
||||
num_val = len(val_lines)
|
||||
|
||||
if local_rank == 0:
|
||||
show_config(
|
||||
classes_path=classes_path, model_path=model_path, input_shape=input_shape, \
|
||||
Init_Epoch=Init_Epoch, Freeze_Epoch=Freeze_Epoch, UnFreeze_Epoch=UnFreeze_Epoch,
|
||||
Freeze_batch_size=Freeze_batch_size, Unfreeze_batch_size=Unfreeze_batch_size, Freeze_Train=Freeze_Train, \
|
||||
Init_lr=Init_lr, Min_lr=Min_lr, optimizer_type=optimizer_type, momentum=momentum,
|
||||
lr_decay_type=lr_decay_type, \
|
||||
save_period=save_period, save_dir=save_dir, num_workers=num_workers, num_train=num_train, num_val=num_val
|
||||
)
|
||||
# ---------------------------------------------------------#
|
||||
# 总训练世代指的是遍历全部数据的总次数
|
||||
# 总训练步长指的是梯度下降的总次数
|
||||
# 每个训练世代包含若干训练步长,每个训练步长进行一次梯度下降。
|
||||
# 此处仅建议最低训练世代,上不封顶,计算时只考虑了解冻部分
|
||||
# ----------------------------------------------------------#
|
||||
wanted_step = 5e4 if optimizer_type == "sgd" else 1.5e4
|
||||
total_step = num_train // Unfreeze_batch_size * UnFreeze_Epoch
|
||||
if total_step <= wanted_step:
|
||||
if num_train // Unfreeze_batch_size == 0:
|
||||
raise ValueError('数据集过小,无法进行训练,请扩充数据集。')
|
||||
wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1
|
||||
print("\n\033[1;33;44m[Warning] 使用%s优化器时,建议将训练总步长设置到%d以上。\033[0m" % (
|
||||
optimizer_type, wanted_step))
|
||||
print(
|
||||
"\033[1;33;44m[Warning] 本次运行的总训练数据量为%d,Unfreeze_batch_size为%d,共训练%d个Epoch,计算出总训练步长为%d。\033[0m" % (
|
||||
num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step))
|
||||
print("\033[1;33;44m[Warning] 由于总训练步长为%d,小于建议总步长%d,建议设置总世代为%d。\033[0m" % (
|
||||
total_step, wanted_step, wanted_epoch))
|
||||
|
||||
# ------------------------------------------------------#
|
||||
# 主干特征提取网络特征通用,冻结训练可以加快训练速度
|
||||
# 也可以在训练初期防止权值被破坏。
|
||||
# Init_Epoch为起始世代
|
||||
# Freeze_Epoch为冻结训练的世代
|
||||
# UnFreeze_Epoch总训练世代
|
||||
# 提示OOM或者显存不足请调小Batch_size
|
||||
# ------------------------------------------------------#
|
||||
if True:
|
||||
UnFreeze_flag = False
|
||||
# ------------------------------------#
|
||||
# 冻结一定部分训练
|
||||
# ------------------------------------#
|
||||
if Freeze_Train:
|
||||
for param in model.backbone.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# -------------------------------------------------------------------#
|
||||
# 如果不冻结训练的话,直接设置batch_size为Unfreeze_batch_size
|
||||
# -------------------------------------------------------------------#
|
||||
batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size
|
||||
|
||||
# -------------------------------------------------------------------#
|
||||
# 判断当前batch_size,自适应调整学习率
|
||||
# -------------------------------------------------------------------#
|
||||
nbs = 64
|
||||
lr_limit_max = 1e-3 if optimizer_type == 'adam' else 5e-2
|
||||
lr_limit_min = 3e-4 if optimizer_type == 'adam' else 5e-4
|
||||
Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
|
||||
Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
|
||||
|
||||
# ---------------------------------------#
|
||||
# 根据optimizer_type选择优化器
|
||||
# ---------------------------------------#
|
||||
pg0, pg1, pg2 = [], [], []
|
||||
for k, v in model.named_modules():
|
||||
if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
|
||||
pg2.append(v.bias)
|
||||
if isinstance(v, nn.BatchNorm2d) or "bn" in k:
|
||||
pg0.append(v.weight)
|
||||
elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
|
||||
pg1.append(v.weight)
|
||||
optimizer = {
|
||||
'adam': optim.Adam(pg0, Init_lr_fit, betas=(momentum, 0.999)),
|
||||
'sgd': optim.SGD(pg0, Init_lr_fit, momentum=momentum, nesterov=True)
|
||||
}[optimizer_type]
|
||||
optimizer.add_param_group({"params": pg1, "weight_decay": weight_decay})
|
||||
optimizer.add_param_group({"params": pg2})
|
||||
|
||||
# ---------------------------------------#
|
||||
# 获得学习率下降的公式
|
||||
# ---------------------------------------#
|
||||
lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
|
||||
|
||||
# ---------------------------------------#
|
||||
# 判断每一个世代的长度
|
||||
# ---------------------------------------#
|
||||
epoch_step = num_train // batch_size
|
||||
epoch_step_val = num_val // batch_size
|
||||
|
||||
if epoch_step == 0 or epoch_step_val == 0:
|
||||
raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")
|
||||
|
||||
if ema:
|
||||
ema.updates = epoch_step * Init_Epoch
|
||||
|
||||
# ---------------------------------------#
|
||||
# 构建数据集加载器。
|
||||
# ---------------------------------------#
|
||||
train_dataset = YoloDataset(train_lines, input_shape, num_classes, epoch_length=UnFreeze_Epoch, \
|
||||
mosaic=mosaic, mixup=mixup, mosaic_prob=mosaic_prob, mixup_prob=mixup_prob,
|
||||
train=True, special_aug_ratio=special_aug_ratio)
|
||||
val_dataset = YoloDataset(val_lines, input_shape, num_classes, epoch_length=UnFreeze_Epoch, \
|
||||
mosaic=False, mixup=False, mosaic_prob=0, mixup_prob=0, train=False,
|
||||
special_aug_ratio=0)
|
||||
|
||||
if distributed:
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, )
|
||||
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, )
|
||||
batch_size = batch_size // ngpus_per_node
|
||||
shuffle = False
|
||||
else:
|
||||
train_sampler = None
|
||||
val_sampler = None
|
||||
shuffle = True
|
||||
|
||||
gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler)
|
||||
gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler)
|
||||
|
||||
# ----------------------#
|
||||
# 记录eval的map曲线
|
||||
# ----------------------#
|
||||
if local_rank == 0:
|
||||
eval_callback = EvalCallback(model, input_shape, class_names, num_classes, val_lines, log_dir, Cuda, \
|
||||
eval_flag=eval_flag, period=eval_period)
|
||||
else:
|
||||
eval_callback = None
|
||||
|
||||
# ---------------------------------------#
|
||||
# 开始模型训练
|
||||
# ---------------------------------------#
|
||||
for epoch in range(Init_Epoch, UnFreeze_Epoch):
|
||||
# ---------------------------------------#
|
||||
# 如果模型有冻结学习部分
|
||||
# 则解冻,并设置参数
|
||||
# ---------------------------------------#
|
||||
if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train:
|
||||
batch_size = Unfreeze_batch_size
|
||||
|
||||
# -------------------------------------------------------------------#
|
||||
# 判断当前batch_size,自适应调整学习率
|
||||
# -------------------------------------------------------------------#
|
||||
nbs = 64
|
||||
lr_limit_max = 1e-3 if optimizer_type == 'adam' else 5e-2
|
||||
lr_limit_min = 3e-4 if optimizer_type == 'adam' else 5e-4
|
||||
Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
|
||||
Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
|
||||
# ---------------------------------------#
|
||||
# 获得学习率下降的公式
|
||||
# ---------------------------------------#
|
||||
lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
|
||||
|
||||
for param in model.backbone.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
epoch_step = num_train // batch_size
|
||||
epoch_step_val = num_val // batch_size
|
||||
|
||||
if epoch_step == 0 or epoch_step_val == 0:
|
||||
raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")
|
||||
|
||||
if distributed:
|
||||
batch_size = batch_size // ngpus_per_node
|
||||
|
||||
if ema:
|
||||
ema.updates = epoch_step * epoch
|
||||
|
||||
gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler)
|
||||
gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler)
|
||||
|
||||
UnFreeze_flag = True
|
||||
|
||||
gen.dataset.epoch_now = epoch
|
||||
gen_val.dataset.epoch_now = epoch
|
||||
|
||||
if distributed:
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
||||
set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
|
||||
|
||||
fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step,
|
||||
epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir,
|
||||
local_rank)
|
||||
|
||||
if distributed:
|
||||
dist.barrier()
|
||||
|
||||
if local_rank == 0:
|
||||
loss_history.writer.close()
|
||||
|
After Width: | Height: | Size: 41 KiB |
|
After Width: | Height: | Size: 67 KiB |
|
After Width: | Height: | Size: 126 KiB |
@@ -0,0 +1 @@
|
||||
http://localhost/img/000009.jpg
|
||||
|
After Width: | Height: | Size: 174 KiB |
|
After Width: | Height: | Size: 145 KiB |
|
After Width: | Height: | Size: 146 KiB |
@@ -0,0 +1 @@
|
||||
#
|
||||
@@ -0,0 +1,227 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import scipy.signal
|
||||
from matplotlib import pyplot as plt
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
import shutil
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from .utils import cvtColor, preprocess_input, resize_image
|
||||
from .utils_bbox import decode_outputs, non_max_suppression
|
||||
from .utils_map import get_coco_map, get_map
|
||||
|
||||
|
||||
class LossHistory():
|
||||
def __init__(self, log_dir, model, input_shape):
|
||||
self.log_dir = log_dir
|
||||
self.losses = []
|
||||
self.val_loss = []
|
||||
|
||||
os.makedirs(self.log_dir)
|
||||
self.writer = SummaryWriter(self.log_dir)
|
||||
try:
|
||||
dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
|
||||
self.writer.add_graph(model, dummy_input)
|
||||
except:
|
||||
pass
|
||||
|
||||
def append_loss(self, epoch, loss, val_loss):
|
||||
if not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir)
|
||||
|
||||
self.losses.append(loss)
|
||||
self.val_loss.append(val_loss)
|
||||
|
||||
with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
|
||||
f.write(str(loss))
|
||||
f.write("\n")
|
||||
with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
|
||||
f.write(str(val_loss))
|
||||
f.write("\n")
|
||||
|
||||
self.writer.add_scalar('loss', loss, epoch)
|
||||
self.writer.add_scalar('val_loss', val_loss, epoch)
|
||||
self.loss_plot()
|
||||
|
||||
def loss_plot(self):
|
||||
iters = range(len(self.losses))
|
||||
|
||||
plt.figure()
|
||||
plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
|
||||
plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
|
||||
try:
|
||||
if len(self.losses) < 25:
|
||||
num = 5
|
||||
else:
|
||||
num = 15
|
||||
|
||||
plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
|
||||
plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
|
||||
except:
|
||||
pass
|
||||
|
||||
plt.grid(True)
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('Loss')
|
||||
plt.legend(loc="upper right")
|
||||
|
||||
plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
|
||||
|
||||
plt.cla()
|
||||
plt.close("all")
|
||||
|
||||
class EvalCallback():
|
||||
def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, \
|
||||
map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
|
||||
super(EvalCallback, self).__init__()
|
||||
|
||||
self.net = net
|
||||
self.input_shape = input_shape
|
||||
self.class_names = class_names
|
||||
self.num_classes = num_classes
|
||||
self.val_lines = val_lines
|
||||
self.log_dir = log_dir
|
||||
self.cuda = cuda
|
||||
self.map_out_path = map_out_path
|
||||
self.max_boxes = max_boxes
|
||||
self.confidence = confidence
|
||||
self.nms_iou = nms_iou
|
||||
self.letterbox_image = letterbox_image
|
||||
self.MINOVERLAP = MINOVERLAP
|
||||
self.eval_flag = eval_flag
|
||||
self.period = period
|
||||
|
||||
self.maps = [0]
|
||||
self.epoches = [0]
|
||||
if self.eval_flag:
|
||||
with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
|
||||
f.write(str(0))
|
||||
f.write("\n")
|
||||
|
||||
def get_map_txt(self, image_id, image, class_names, map_out_path):
|
||||
f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w")
|
||||
image_shape = np.array(np.shape(image)[0:2])
|
||||
#---------------------------------------------------------#
|
||||
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
|
||||
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
|
||||
#---------------------------------------------------------#
|
||||
image = cvtColor(image)
|
||||
#---------------------------------------------------------#
|
||||
# 给图像增加灰条,实现不失真的resize
|
||||
# 也可以直接resize进行识别
|
||||
#---------------------------------------------------------#
|
||||
image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
|
||||
#---------------------------------------------------------#
|
||||
# 添加上batch_size维度
|
||||
#---------------------------------------------------------#
|
||||
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
|
||||
|
||||
with torch.no_grad():
|
||||
images = torch.from_numpy(image_data)
|
||||
if self.cuda:
|
||||
images = images.cuda()
|
||||
#---------------------------------------------------------#
|
||||
# 将图像输入网络当中进行预测!
|
||||
#---------------------------------------------------------#
|
||||
outputs = self.net(images)
|
||||
outputs = decode_outputs(outputs, self.input_shape)
|
||||
#---------------------------------------------------------#
|
||||
# 将预测框进行堆叠,然后进行非极大抑制
|
||||
#---------------------------------------------------------#
|
||||
results = non_max_suppression(outputs, self.num_classes, self.input_shape,
|
||||
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
|
||||
|
||||
if results[0] is None:
|
||||
return
|
||||
|
||||
top_label = np.array(results[0][:, 6], dtype = 'int32')
|
||||
top_conf = results[0][:, 4] * results[0][:, 5]
|
||||
top_boxes = results[0][:, :4]
|
||||
|
||||
top_100 = np.argsort(top_conf)[::-1][:self.max_boxes]
|
||||
top_boxes = top_boxes[top_100]
|
||||
top_conf = top_conf[top_100]
|
||||
top_label = top_label[top_100]
|
||||
|
||||
for i, c in list(enumerate(top_label)):
|
||||
predicted_class = self.class_names[int(c)]
|
||||
box = top_boxes[i]
|
||||
score = str(top_conf[i])
|
||||
|
||||
top, left, bottom, right = box
|
||||
if predicted_class not in class_names:
|
||||
continue
|
||||
|
||||
f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
|
||||
|
||||
f.close()
|
||||
return
|
||||
|
||||
def on_epoch_end(self, epoch, model_eval):
|
||||
if epoch % self.period == 0 and self.eval_flag:
|
||||
self.net = model_eval
|
||||
if not os.path.exists(self.map_out_path):
|
||||
os.makedirs(self.map_out_path)
|
||||
if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")):
|
||||
os.makedirs(os.path.join(self.map_out_path, "ground-truth"))
|
||||
if not os.path.exists(os.path.join(self.map_out_path, "detection-results")):
|
||||
os.makedirs(os.path.join(self.map_out_path, "detection-results"))
|
||||
print("Get map.")
|
||||
for annotation_line in tqdm(self.val_lines):
|
||||
line = annotation_line.split()
|
||||
image_id = os.path.basename(line[0]).split('.')[0]
|
||||
#------------------------------#
|
||||
# 读取图像并转换成RGB图像
|
||||
#------------------------------#
|
||||
image = Image.open(line[0])
|
||||
#------------------------------#
|
||||
# 获得预测框
|
||||
#------------------------------#
|
||||
gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
|
||||
#------------------------------#
|
||||
# 获得预测txt
|
||||
#------------------------------#
|
||||
self.get_map_txt(image_id, image, self.class_names, self.map_out_path)
|
||||
|
||||
#------------------------------#
|
||||
# 获得真实框txt
|
||||
#------------------------------#
|
||||
with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
|
||||
for box in gt_boxes:
|
||||
left, top, right, bottom, obj = box
|
||||
obj_name = self.class_names[obj]
|
||||
new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
|
||||
|
||||
print("Calculate Map.")
|
||||
try:
|
||||
temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
|
||||
except:
|
||||
temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
|
||||
self.maps.append(temp_map)
|
||||
self.epoches.append(epoch)
|
||||
|
||||
with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
|
||||
f.write(str(temp_map))
|
||||
f.write("\n")
|
||||
|
||||
plt.figure()
|
||||
plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')
|
||||
|
||||
plt.grid(True)
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('Map %s'%str(self.MINOVERLAP))
|
||||
plt.title('A Map Curve')
|
||||
plt.legend(loc="upper right")
|
||||
|
||||
plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
|
||||
plt.cla()
|
||||
plt.close("all")
|
||||
|
||||
print("Get map done.")
|
||||
shutil.rmtree(self.map_out_path)
|
||||
@@ -0,0 +1,374 @@
|
||||
from random import sample, shuffle
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
from utils.utils import cvtColor, preprocess_input
|
||||
|
||||
|
||||
class YoloDataset(Dataset):
|
||||
def __init__(self, annotation_lines, input_shape, num_classes, epoch_length, \
|
||||
mosaic, mixup, mosaic_prob, mixup_prob, train, special_aug_ratio = 0.7):
|
||||
super(YoloDataset, self).__init__()
|
||||
self.annotation_lines = annotation_lines
|
||||
self.input_shape = input_shape
|
||||
self.num_classes = num_classes
|
||||
self.epoch_length = epoch_length
|
||||
self.mosaic = mosaic
|
||||
self.mosaic_prob = mosaic_prob
|
||||
self.mixup = mixup
|
||||
self.mixup_prob = mixup_prob
|
||||
self.train = train
|
||||
self.special_aug_ratio = special_aug_ratio
|
||||
|
||||
self.epoch_now = -1
|
||||
self.length = len(self.annotation_lines)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, index):
|
||||
index = index % self.length
|
||||
|
||||
#---------------------------------------------------#
|
||||
# 训练时进行数据的随机增强
|
||||
# 验证时不进行数据的随机增强
|
||||
#---------------------------------------------------#
|
||||
if self.mosaic and self.rand() < self.mosaic_prob and self.epoch_now < self.epoch_length * self.special_aug_ratio:
|
||||
lines = sample(self.annotation_lines, 3)
|
||||
lines.append(self.annotation_lines[index])
|
||||
shuffle(lines)
|
||||
image, box = self.get_random_data_with_Mosaic(lines, self.input_shape)
|
||||
|
||||
if self.mixup and self.rand() < self.mixup_prob:
|
||||
lines = sample(self.annotation_lines, 1)
|
||||
image_2, box_2 = self.get_random_data(lines[0], self.input_shape, random = self.train)
|
||||
image, box = self.get_random_data_with_MixUp(image, box, image_2, box_2)
|
||||
else:
|
||||
image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, random = self.train)
|
||||
|
||||
image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
|
||||
box = np.array(box, dtype=np.float32)
|
||||
if len(box) != 0:
|
||||
box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
|
||||
box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
|
||||
return image, box
|
||||
|
||||
def rand(self, a=0, b=1):
|
||||
return np.random.rand()*(b-a) + a
|
||||
|
||||
def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
|
||||
line = annotation_line.split()
|
||||
#------------------------------#
|
||||
# 读取图像并转换成RGB图像
|
||||
#------------------------------#
|
||||
image = Image.open(line[0])
|
||||
image = cvtColor(image)
|
||||
#------------------------------#
|
||||
# 获得图像的高宽与目标高宽
|
||||
#------------------------------#
|
||||
iw, ih = image.size
|
||||
h, w = input_shape
|
||||
#------------------------------#
|
||||
# 获得预测框
|
||||
#------------------------------#
|
||||
box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
|
||||
|
||||
if not random:
|
||||
scale = min(w/iw, h/ih)
|
||||
nw = int(iw*scale)
|
||||
nh = int(ih*scale)
|
||||
dx = (w-nw)//2
|
||||
dy = (h-nh)//2
|
||||
|
||||
#---------------------------------#
|
||||
# 将图像多余的部分加上灰条
|
||||
#---------------------------------#
|
||||
image = image.resize((nw,nh), Image.BICUBIC)
|
||||
new_image = Image.new('RGB', (w,h), (128,128,128))
|
||||
new_image.paste(image, (dx, dy))
|
||||
image_data = np.array(new_image, np.float32)
|
||||
|
||||
#---------------------------------#
|
||||
# 对真实框进行调整
|
||||
#---------------------------------#
|
||||
if len(box)>0:
|
||||
np.random.shuffle(box)
|
||||
box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
|
||||
box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
|
||||
box[:, 0:2][box[:, 0:2]<0] = 0
|
||||
box[:, 2][box[:, 2]>w] = w
|
||||
box[:, 3][box[:, 3]>h] = h
|
||||
box_w = box[:, 2] - box[:, 0]
|
||||
box_h = box[:, 3] - box[:, 1]
|
||||
box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box
|
||||
|
||||
return image_data, box
|
||||
|
||||
#------------------------------------------#
|
||||
# 对图像进行缩放并且进行长和宽的扭曲
|
||||
#------------------------------------------#
|
||||
new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
|
||||
scale = self.rand(.25, 2)
|
||||
if new_ar < 1:
|
||||
nh = int(scale*h)
|
||||
nw = int(nh*new_ar)
|
||||
else:
|
||||
nw = int(scale*w)
|
||||
nh = int(nw/new_ar)
|
||||
image = image.resize((nw,nh), Image.BICUBIC)
|
||||
|
||||
#------------------------------------------#
|
||||
# 将图像多余的部分加上灰条
|
||||
#------------------------------------------#
|
||||
dx = int(self.rand(0, w-nw))
|
||||
dy = int(self.rand(0, h-nh))
|
||||
new_image = Image.new('RGB', (w,h), (128,128,128))
|
||||
new_image.paste(image, (dx, dy))
|
||||
image = new_image
|
||||
|
||||
#------------------------------------------#
|
||||
# 翻转图像
|
||||
#------------------------------------------#
|
||||
flip = self.rand()<.5
|
||||
if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
|
||||
image_data = np.array(image, np.uint8)
|
||||
#---------------------------------#
|
||||
# 对图像进行色域变换
|
||||
# 计算色域变换的参数
|
||||
#---------------------------------#
|
||||
r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
|
||||
#---------------------------------#
|
||||
# 将图像转到HSV上
|
||||
#---------------------------------#
|
||||
hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
|
||||
dtype = image_data.dtype
|
||||
#---------------------------------#
|
||||
# 应用变换
|
||||
#---------------------------------#
|
||||
x = np.arange(0, 256, dtype=r.dtype)
|
||||
lut_hue = ((x * r[0]) % 180).astype(dtype)
|
||||
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
|
||||
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
|
||||
|
||||
image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
|
||||
image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
|
||||
|
||||
#---------------------------------#
|
||||
# 对真实框进行调整
|
||||
#---------------------------------#
|
||||
if len(box)>0:
|
||||
np.random.shuffle(box)
|
||||
box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
|
||||
box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
|
||||
if flip: box[:, [0,2]] = w - box[:, [2,0]]
|
||||
box[:, 0:2][box[:, 0:2]<0] = 0
|
||||
box[:, 2][box[:, 2]>w] = w
|
||||
box[:, 3][box[:, 3]>h] = h
|
||||
box_w = box[:, 2] - box[:, 0]
|
||||
box_h = box[:, 3] - box[:, 1]
|
||||
box = box[np.logical_and(box_w>1, box_h>1)]
|
||||
|
||||
return image_data, box
|
||||
|
||||
def merge_bboxes(self, bboxes, cutx, cuty):
|
||||
merge_bbox = []
|
||||
for i in range(len(bboxes)):
|
||||
for box in bboxes[i]:
|
||||
tmp_box = []
|
||||
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
|
||||
|
||||
if i == 0:
|
||||
if y1 > cuty or x1 > cutx:
|
||||
continue
|
||||
if y2 >= cuty and y1 <= cuty:
|
||||
y2 = cuty
|
||||
if x2 >= cutx and x1 <= cutx:
|
||||
x2 = cutx
|
||||
|
||||
if i == 1:
|
||||
if y2 < cuty or x1 > cutx:
|
||||
continue
|
||||
if y2 >= cuty and y1 <= cuty:
|
||||
y1 = cuty
|
||||
if x2 >= cutx and x1 <= cutx:
|
||||
x2 = cutx
|
||||
|
||||
if i == 2:
|
||||
if y2 < cuty or x2 < cutx:
|
||||
continue
|
||||
if y2 >= cuty and y1 <= cuty:
|
||||
y1 = cuty
|
||||
if x2 >= cutx and x1 <= cutx:
|
||||
x1 = cutx
|
||||
|
||||
if i == 3:
|
||||
if y1 > cuty or x2 < cutx:
|
||||
continue
|
||||
if y2 >= cuty and y1 <= cuty:
|
||||
y2 = cuty
|
||||
if x2 >= cutx and x1 <= cutx:
|
||||
x1 = cutx
|
||||
tmp_box.append(x1)
|
||||
tmp_box.append(y1)
|
||||
tmp_box.append(x2)
|
||||
tmp_box.append(y2)
|
||||
tmp_box.append(box[-1])
|
||||
merge_bbox.append(tmp_box)
|
||||
return merge_bbox
|
||||
|
||||
def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4):
|
||||
h, w = input_shape
|
||||
min_offset_x = self.rand(0.3, 0.7)
|
||||
min_offset_y = self.rand(0.3, 0.7)
|
||||
|
||||
image_datas = []
|
||||
box_datas = []
|
||||
index = 0
|
||||
for line in annotation_line:
|
||||
#---------------------------------#
|
||||
# 每一行进行分割
|
||||
#---------------------------------#
|
||||
line_content = line.split()
|
||||
#---------------------------------#
|
||||
# 打开图片
|
||||
#---------------------------------#
|
||||
image = Image.open(line_content[0])
|
||||
image = cvtColor(image)
|
||||
|
||||
#---------------------------------#
|
||||
# 图片的大小
|
||||
#---------------------------------#
|
||||
iw, ih = image.size
|
||||
#---------------------------------#
|
||||
# 保存框的位置
|
||||
#---------------------------------#
|
||||
box = np.array([np.array(list(map(int,box.split(',')))) for box in line_content[1:]])
|
||||
|
||||
#---------------------------------#
|
||||
# 是否翻转图片
|
||||
#---------------------------------#
|
||||
flip = self.rand()<.5
|
||||
if flip and len(box)>0:
|
||||
image = image.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
box[:, [0,2]] = iw - box[:, [2,0]]
|
||||
|
||||
#------------------------------------------#
|
||||
# 对图像进行缩放并且进行长和宽的扭曲
|
||||
#------------------------------------------#
|
||||
new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
|
||||
scale = self.rand(.4, 1)
|
||||
if new_ar < 1:
|
||||
nh = int(scale*h)
|
||||
nw = int(nh*new_ar)
|
||||
else:
|
||||
nw = int(scale*w)
|
||||
nh = int(nw/new_ar)
|
||||
image = image.resize((nw, nh), Image.BICUBIC)
|
||||
|
||||
#-----------------------------------------------#
|
||||
# 将图片进行放置,分别对应四张分割图片的位置
|
||||
#-----------------------------------------------#
|
||||
if index == 0:
|
||||
dx = int(w*min_offset_x) - nw
|
||||
dy = int(h*min_offset_y) - nh
|
||||
elif index == 1:
|
||||
dx = int(w*min_offset_x) - nw
|
||||
dy = int(h*min_offset_y)
|
||||
elif index == 2:
|
||||
dx = int(w*min_offset_x)
|
||||
dy = int(h*min_offset_y)
|
||||
elif index == 3:
|
||||
dx = int(w*min_offset_x)
|
||||
dy = int(h*min_offset_y) - nh
|
||||
|
||||
new_image = Image.new('RGB', (w,h), (128,128,128))
|
||||
new_image.paste(image, (dx, dy))
|
||||
image_data = np.array(new_image)
|
||||
|
||||
index = index + 1
|
||||
box_data = []
|
||||
#---------------------------------#
|
||||
# 对box进行重新处理
|
||||
#---------------------------------#
|
||||
if len(box)>0:
|
||||
np.random.shuffle(box)
|
||||
box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
|
||||
box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
|
||||
box[:, 0:2][box[:, 0:2]<0] = 0
|
||||
box[:, 2][box[:, 2]>w] = w
|
||||
box[:, 3][box[:, 3]>h] = h
|
||||
box_w = box[:, 2] - box[:, 0]
|
||||
box_h = box[:, 3] - box[:, 1]
|
||||
box = box[np.logical_and(box_w>1, box_h>1)]
|
||||
box_data = np.zeros((len(box),5))
|
||||
box_data[:len(box)] = box
|
||||
|
||||
image_datas.append(image_data)
|
||||
box_datas.append(box_data)
|
||||
|
||||
#---------------------------------#
|
||||
# 将图片分割,放在一起
|
||||
#---------------------------------#
|
||||
cutx = int(w * min_offset_x)
|
||||
cuty = int(h * min_offset_y)
|
||||
|
||||
new_image = np.zeros([h, w, 3])
|
||||
new_image[:cuty, :cutx, :] = image_datas[0][:cuty, :cutx, :]
|
||||
new_image[cuty:, :cutx, :] = image_datas[1][cuty:, :cutx, :]
|
||||
new_image[cuty:, cutx:, :] = image_datas[2][cuty:, cutx:, :]
|
||||
new_image[:cuty, cutx:, :] = image_datas[3][:cuty, cutx:, :]
|
||||
|
||||
new_image = np.array(new_image, np.uint8)
|
||||
#---------------------------------#
|
||||
# 对图像进行色域变换
|
||||
# 计算色域变换的参数
|
||||
#---------------------------------#
|
||||
r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
|
||||
#---------------------------------#
|
||||
# 将图像转到HSV上
|
||||
#---------------------------------#
|
||||
hue, sat, val = cv2.split(cv2.cvtColor(new_image, cv2.COLOR_RGB2HSV))
|
||||
dtype = new_image.dtype
|
||||
#---------------------------------#
|
||||
# 应用变换
|
||||
#---------------------------------#
|
||||
x = np.arange(0, 256, dtype=r.dtype)
|
||||
lut_hue = ((x * r[0]) % 180).astype(dtype)
|
||||
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
|
||||
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
|
||||
|
||||
new_image = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
|
||||
new_image = cv2.cvtColor(new_image, cv2.COLOR_HSV2RGB)
|
||||
|
||||
#---------------------------------#
|
||||
# 对框进行进一步的处理
|
||||
#---------------------------------#
|
||||
new_boxes = self.merge_bboxes(box_datas, cutx, cuty)
|
||||
|
||||
return new_image, new_boxes
|
||||
|
||||
def get_random_data_with_MixUp(self, image_1, box_1, image_2, box_2):
|
||||
new_image = np.array(image_1, np.float32) * 0.5 + np.array(image_2, np.float32) * 0.5
|
||||
if len(box_1) == 0:
|
||||
new_boxes = box_2
|
||||
elif len(box_2) == 0:
|
||||
new_boxes = box_1
|
||||
else:
|
||||
new_boxes = np.concatenate([box_1, box_2], axis=0)
|
||||
return new_image, new_boxes
|
||||
|
||||
# DataLoader中collate_fn使用
|
||||
def yolo_dataset_collate(batch):
|
||||
images = []
|
||||
bboxes = []
|
||||
for img, box in batch:
|
||||
images.append(img)
|
||||
bboxes.append(box)
|
||||
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
|
||||
bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
|
||||
return images, bboxes
|
||||
@@ -0,0 +1,64 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
#---------------------------------------------------------#
|
||||
# 将图像转换成RGB图像,防止灰度图在预测时报错。
|
||||
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
|
||||
#---------------------------------------------------------#
|
||||
def cvtColor(image):
|
||||
if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
|
||||
return image
|
||||
else:
|
||||
image = image.convert('RGB')
|
||||
return image
|
||||
|
||||
#---------------------------------------------------#
|
||||
# 对输入图像进行resize
|
||||
#---------------------------------------------------#
|
||||
def resize_image(image, size, letterbox_image):
|
||||
iw, ih = image.size
|
||||
w, h = size
|
||||
if letterbox_image:
|
||||
scale = min(w/iw, h/ih)
|
||||
nw = int(iw*scale)
|
||||
nh = int(ih*scale)
|
||||
|
||||
image = image.resize((nw,nh), Image.BICUBIC)
|
||||
new_image = Image.new('RGB', size, (128,128,128))
|
||||
new_image.paste(image, ((w-nw)//2, (h-nh)//2))
|
||||
else:
|
||||
new_image = image.resize((w, h), Image.BICUBIC)
|
||||
return new_image
|
||||
|
||||
#---------------------------------------------------#
|
||||
# 获得类
|
||||
#---------------------------------------------------#
|
||||
def get_classes(classes_path):
|
||||
with open(classes_path, encoding='utf-8') as f:
|
||||
class_names = f.readlines()
|
||||
class_names = [c.strip() for c in class_names]
|
||||
return class_names, len(class_names)
|
||||
|
||||
def preprocess_input(image):
|
||||
image /= 255.0
|
||||
image -= np.array([0.485, 0.456, 0.406])
|
||||
image /= np.array([0.229, 0.224, 0.225])
|
||||
return image
|
||||
|
||||
#---------------------------------------------------#
|
||||
# 获得学习率
|
||||
#---------------------------------------------------#
|
||||
def get_lr(optimizer):
|
||||
for param_group in optimizer.param_groups:
|
||||
return param_group['lr']
|
||||
|
||||
def show_config(**kwargs):
|
||||
print('Configurations:')
|
||||
print('-' * 70)
|
||||
print('|%25s | %40s|' % ('keys', 'values'))
|
||||
print('-' * 70)
|
||||
for key, value in kwargs.items():
|
||||
print('|%25s | %40s|' % (str(key), str(value)))
|
||||
print('-' * 70)
|
||||
|
||||