浏览代码

修复文件参数必填验证失效的`BUG`

mxd 3 年之前
父节点
当前提交
17c9cb9dbb

+ 5 - 1
magic-api-spring-boot-starter/src/main/java/org/ssssssss/magicapi/spring/boot/starter/MagicAPIAutoConfiguration.java

@@ -25,6 +25,7 @@ import org.springframework.jdbc.core.JdbcTemplate;
 import org.springframework.web.bind.annotation.ResponseBody;
 import org.springframework.web.client.RestTemplate;
 import org.springframework.web.multipart.MultipartFile;
+import org.springframework.web.multipart.MultipartResolver;
 import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
 import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry;
 import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
@@ -142,6 +143,9 @@ public class MagicAPIAutoConfiguration implements WebMvcConfigurer, WebSocketCon
 	@Lazy
 	private RequestMappingHandlerMapping requestMappingHandlerMapping;
 
+	@Autowired(required = false)
+	private MultipartResolver multipartResolver;
+
 	private final ObjectProvider<RestTemplate> restTemplateProvider;
 
 	private String ALL_CLASS_TXT;
@@ -457,7 +461,7 @@ public class MagicAPIAutoConfiguration implements WebMvcConfigurer, WebSocketCon
 		logger.info("注册模块:{} -> {}", "env", EnvModule.class);
 		MagicResourceLoader.addModule("env", new EnvModule(environment));
 		logger.info("注册模块:{} -> {}", "request", RequestModule.class);
-		MagicResourceLoader.addModule("request", new RequestModule());
+		MagicResourceLoader.addModule("request", new RequestModule(multipartResolver));
 		logger.info("注册模块:{} -> {}", "response", ResponseModule.class);
 		MagicResourceLoader.addModule("response", new ResponseModule(resultProvider));
 		logger.info("注册模块:{} -> {}", "assert", AssertModule.class);

+ 9 - 1
magic-api/src/main/java/org/ssssssss/magicapi/controller/RequestHandler.java

@@ -173,8 +173,9 @@ public class RequestHandler extends MagicController {
 				}
 
 			} else if (StringUtils.isNotBlank(parameter.getName()) || parameters.containsKey(parameter.getName())) {
+				boolean isFile = parameter.getDataType() == DataType.MultipartFile || parameter.getDataType() == DataType.MultipartFiles;
 				String requestValue = StringUtils.defaultIfBlank(Objects.toString(parameters.get(parameter.getName()), EMPTY), Objects.toString(parameter.getDefaultValue(), EMPTY));
-				if (StringUtils.isBlank(requestValue)) {
+				if (StringUtils.isBlank(requestValue) && !isFile) {
 					if (!parameter.isRequired()) {
 						continue;
 					}
@@ -182,6 +183,11 @@ public class RequestHandler extends MagicController {
 				}
 				try {
 					Object value = convertValue(parameter.getDataType(), parameter.getName(), requestValue);
+					if (isFile && parameter.isRequired()) {
+						if (value == null || (parameter.getDataType() == DataType.MultipartFiles && ((List<?>) value).isEmpty())) {
+							throw new ValidateException(jsonCode, StringUtils.defaultIfBlank(parameter.getError(), String.format("%s[%s]为必填项", comment, parameter.getName())));
+						}
+					}
 					if (VALIDATE_TYPE_PATTERN.equals(parameter.getValidateType())) {    // 正则验证
 						String expression = parameter.getExpression();
 						if (StringUtils.isNotBlank(expression) && !PatternUtils.match(Objects.toString(value, EMPTY), expression)) {
@@ -189,6 +195,8 @@ public class RequestHandler extends MagicController {
 						}
 					}
 					parameters.put(parameter.getName(), value);
+				} catch (ValidateException ve) {
+					throw ve;
 				} catch (Exception e) {
 					throw new ValidateException(jsonCode, StringUtils.defaultIfBlank(parameter.getError(), String.format("%s[%s]不合法", comment, parameter.getName())));
 				}

+ 20 - 7
magic-api/src/main/java/org/ssssssss/magicapi/modules/RequestModule.java

@@ -2,8 +2,7 @@ package org.ssssssss.magicapi.modules;
 
 import org.springframework.web.multipart.MultipartFile;
 import org.springframework.web.multipart.MultipartRequest;
-import org.springframework.web.util.WebUtils;
-import org.ssssssss.magicapi.context.RequestContext;
+import org.springframework.web.multipart.MultipartResolver;
 import org.ssssssss.script.annotation.Comment;
 
 import javax.servlet.http.HttpServletRequest;
@@ -11,12 +10,19 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.Enumeration;
 import java.util.List;
+import java.util.stream.Collectors;
 
 /**
  * request 模块
  */
 public class RequestModule {
 
+	private static MultipartResolver resolver;
+
+	public RequestModule(MultipartResolver resolver) {
+		RequestModule.resolver = resolver;
+	}
+
 	/**
 	 * 获取文件信息
 	 *
@@ -25,7 +31,11 @@ public class RequestModule {
 	@Comment("获取文件")
 	public static MultipartFile getFile(@Comment("参数名") String name) {
 		MultipartRequest request = getMultipartHttpServletRequest();
-		return request == null ? null : request.getFile(name);
+		if (request == null) {
+			return null;
+		}
+		MultipartFile file = request.getFile(name);
+		return file == null || file.isEmpty() ? null : file;
 	}
 
 	/**
@@ -36,20 +46,23 @@ public class RequestModule {
 	@Comment("获取多个文件")
 	public static List<MultipartFile> getFiles(@Comment("参数名") String name) {
 		MultipartRequest request = getMultipartHttpServletRequest();
-		return request == null ? null : request.getFiles(name);
+		if (request == null) {
+			return null;
+		}
+		return request.getFiles(name).stream().filter(it -> !it.isEmpty()).collect(Collectors.toList());
 	}
 
 	/**
 	 * 获取原生HttpServletRequest对象
 	 */
 	public static HttpServletRequest get() {
-		return RequestContext.getHttpServletRequest();
+		return org.ssssssss.magicapi.utils.WebUtils.getRequest().orElse(null);
 	}
 
 	private static MultipartRequest getMultipartHttpServletRequest() {
 		HttpServletRequest request = get();
-		if (request != null && request.getContentType() != null && request.getContentType().toLowerCase().startsWith("multipart/")) {
-			return WebUtils.getNativeRequest(request, MultipartRequest.class);
+		if (request != null && resolver.isMultipart(request)) {
+			return resolver.resolveMultipart(request);
 		}
 		return null;
 	}