Quellcode durchsuchen

增加Request模块

mxd vor 4 Jahren
Ursprung
Commit
514af3e545

+ 1 - 1
src/main/java/org/ssssssss/magicapi/config/MappingHandlerMapping.java

@@ -46,7 +46,7 @@ public class MappingHandlerMapping {
     /**
      * 请求到达时处理的方法
      */
-    private Method method = RequestHandler.class.getDeclaredMethod("invoke", HttpServletRequest.class, HttpServletResponse.class, Map.class, Map.class, Map.class);
+    private Method method = RequestHandler.class.getDeclaredMethod("invoke", HttpServletRequest.class, HttpServletResponse.class, Map.class, Map.class);
 
     /**
      * 接口信息读取

+ 23 - 5
src/main/java/org/ssssssss/magicapi/config/RequestHandler.java

@@ -2,13 +2,17 @@ package org.ssssssss.magicapi.config;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.springframework.http.MediaType;
 import org.springframework.http.ResponseEntity;
+import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.http.server.ServletServerHttpRequest;
 import org.springframework.web.bind.annotation.PathVariable;
 import org.springframework.web.bind.annotation.RequestBody;
 import org.springframework.web.bind.annotation.RequestParam;
 import org.springframework.web.bind.annotation.ResponseBody;
 import org.ssssssss.magicapi.context.CookieContext;
 import org.ssssssss.magicapi.context.HeaderContext;
+import org.ssssssss.magicapi.context.RequestContext;
 import org.ssssssss.magicapi.context.SessionContext;
 import org.ssssssss.magicapi.provider.ResultProvider;
 import org.ssssssss.magicapi.script.ScriptManager;
@@ -54,6 +58,12 @@ public class RequestHandler {
 		this.throwException = throwException;
 	}
 
+	private List<HttpMessageConverter<?>> httpMessageConverters;
+
+	public void setHttpMessageConverters(List<HttpMessageConverter<?>> httpMessageConverters) {
+		this.httpMessageConverters = httpMessageConverters;
+	}
+
 	/**
 	 * 打印banner
 	 */
@@ -69,10 +79,10 @@ public class RequestHandler {
 	@ResponseBody
 	public Object invoke(HttpServletRequest request, HttpServletResponse response,
 						 @PathVariable(required = false) Map<String, Object> pathVariables,
-						 @RequestParam(required = false) Map<String, Object> parameters,
-						 @RequestBody(required = false) Map<String, Object> requestBody) throws Throwable {
+						 @RequestParam(required = false) Map<String, Object> parameters) throws Throwable {
 		ApiInfo info;
 		try {
+			RequestContext.setRequestAttribute(request, response);
 			//	找到对应的接口信息
 			info = MappingHandlerMapping.getMappingApiInfo(request);
 			if(info==null){
@@ -87,8 +97,15 @@ public class RequestHandler {
 			context.set("header", new HeaderContext(request));
 			context.set("session", new SessionContext(request.getSession()));
 			context.set("path", pathVariables);
-			if (requestBody != null) {
-				context.set("body", requestBody);
+			if (httpMessageConverters != null && request.getContentType() != null) {
+				MediaType mediaType = MediaType.valueOf(request.getContentType());
+				Class clazz = Map.class;
+				for (HttpMessageConverter<?> converter : httpMessageConverters) {
+					if(converter.canRead(clazz,mediaType)){
+						context.set("body", converter.read(clazz,new ServletServerHttpRequest(request)));
+						break;
+					}
+				}
 			}
 			// 执行前置拦截器
 			for (RequestInterceptor requestInterceptor : requestInterceptors) {
@@ -118,7 +135,8 @@ public class RequestHandler {
 			}
 			logger.error("接口请求出错", root);
 			return resultProvider.buildResult(root);
+		} finally {
+			RequestContext.remove();
 		}
-
 	}
 }

+ 40 - 0
src/main/java/org/ssssssss/magicapi/context/RequestContext.java

@@ -0,0 +1,40 @@
+package org.ssssssss.magicapi.context;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+public class RequestContext {
+
+	private static final ThreadLocal<RequestAttribute> REQUEST_ATTRIBUTE_THREAD_LOCAL = new InheritableThreadLocal<>();
+
+	public static void setRequestAttribute(HttpServletRequest request,HttpServletResponse response){
+		REQUEST_ATTRIBUTE_THREAD_LOCAL.set(new RequestAttribute(request, response));
+	}
+	
+	public static HttpServletRequest getHttpServletRequest(){
+		RequestAttribute requestAttribute = REQUEST_ATTRIBUTE_THREAD_LOCAL.get();
+		return  requestAttribute == null ? null : requestAttribute.request;
+	}
+
+	public static HttpServletResponse getHttpServletResponse(){
+		RequestAttribute requestAttribute = REQUEST_ATTRIBUTE_THREAD_LOCAL.get();
+		return  requestAttribute == null ? null : requestAttribute.response;
+	}
+
+	public static void remove(){
+		REQUEST_ATTRIBUTE_THREAD_LOCAL.remove();
+	}
+
+
+	private static class RequestAttribute{
+
+		private HttpServletRequest request;
+
+		private HttpServletResponse response;
+
+		public RequestAttribute(HttpServletRequest request, HttpServletResponse response) {
+			this.request = request;
+			this.response = response;
+		}
+	}
+}

+ 85 - 0
src/main/java/org/ssssssss/magicapi/functions/RequestFunctions.java

@@ -0,0 +1,85 @@
+package org.ssssssss.magicapi.functions;
+
+import org.springframework.web.context.request.RequestAttributes;
+import org.springframework.web.context.request.RequestContextHolder;
+import org.springframework.web.context.request.ServletRequestAttributes;
+import org.springframework.web.multipart.MultipartFile;
+import org.springframework.web.multipart.MultipartRequest;
+import org.springframework.web.util.WebUtils;
+import org.ssssssss.magicapi.context.RequestContext;
+
+import javax.servlet.http.HttpServletRequest;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.List;
+
+/**
+ * request 模块
+ */
+public class RequestFunctions {
+
+	/**
+	 * 获取文件信息
+	 *
+	 * @param name 参数名
+	 */
+	public MultipartFile getFile(String name) {
+		MultipartRequest request = getMultipartHttpServletRequest();
+		return request == null ? null : request.getFile(name);
+	}
+
+	/**
+	 * 获取文件信息
+	 *
+	 * @param name 参数名
+	 */
+	public List<MultipartFile> getFiles(String name) {
+		MultipartRequest request = getMultipartHttpServletRequest();
+		return request == null ? null : request.getFiles(name);
+	}
+
+	/**
+	 * 根据参数名获取参数值集合
+	 *
+	 * @param name 参数名
+	 */
+	public List<String> getValues(String name) {
+		HttpServletRequest request = get();
+		if (request != null) {
+			String[] values = request.getParameterValues(name);
+			return values == null ? null : Arrays.asList(values);
+		}
+		return null;
+	}
+
+	/**
+	 * 根据header名获取header集合
+	 *
+	 * @param name 参数名
+	 */
+	public List<String> getHeaders(String name) {
+		HttpServletRequest request = get();
+		if (request != null) {
+			Enumeration<String> headers = request.getHeaders(name);
+			return headers == null ? null : Collections.list(headers);
+		}
+		return null;
+	}
+
+	/**
+	 * 获取原生HttpServletRequest对象
+	 */
+	public HttpServletRequest get() {
+		return RequestContext.getHttpServletRequest();
+	}
+
+	private MultipartRequest getMultipartHttpServletRequest() {
+		HttpServletRequest request = get();
+		if (request != null && request.getContentType() != null && request.getContentType().toLowerCase().startsWith("multipart/")) {
+			return WebUtils.getNativeRequest(request, MultipartRequest.class);
+		}
+		return null;
+	}
+
+}