Browse Source

代码优化

mxd 3 years ago
parent
commit
0bff6c555f

+ 8 - 1
magic-api/src/main/java/org/ssssssss/magicapi/core/handler/MagicWorkbenchHandler.java

@@ -1,5 +1,7 @@
 package org.ssssssss.magicapi.core.handler;
 
+import org.apache.commons.lang3.StringUtils;
+import org.springframework.http.HttpHeaders;
 import org.ssssssss.magicapi.core.annotation.Message;
 import org.ssssssss.magicapi.core.config.MessageType;
 import org.ssssssss.magicapi.core.config.WebSocketSessionManager;
@@ -8,9 +10,12 @@ import org.ssssssss.magicapi.core.interceptor.AuthorizationInterceptor;
 import org.ssssssss.magicapi.core.context.MagicUser;
 import org.ssssssss.magicapi.core.config.Constants;
 import org.ssssssss.magicapi.core.context.MagicConsoleSession;
+import org.ssssssss.magicapi.modules.servlet.RequestModule;
+import org.ssssssss.magicapi.utils.IpUtils;
 
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.stream.Collectors;
 
@@ -35,8 +40,10 @@ public class MagicWorkbenchHandler {
 			MagicUser user = guest;
 			if (!authorizationInterceptor.requireLogin() || (user = authorizationInterceptor.getUserByToken(token)) != null) {
 				String ip = Optional.ofNullable(session.getWebSocketSession().getRemoteAddress()).map(it -> it.getAddress().getHostAddress()).orElse("unknown");
+				HttpHeaders headers = session.getWebSocketSession().getHandshakeHeaders();
+				ip = IpUtils.getRealIP(ip, headers::getFirst, null);
 				session.setAttribute(Constants.WEBSOCKET_ATTRIBUTE_USER_ID, user.getId());
-				session.setAttribute(Constants.WEBSOCKET_ATTRIBUTE_USER_IP, ip);
+				session.setAttribute(Constants.WEBSOCKET_ATTRIBUTE_USER_IP, StringUtils.defaultIfBlank(ip, "unknown"));
 				session.setAttribute(Constants.WEBSOCKET_ATTRIBUTE_USER_NAME, user.getUsername());
 				session.setClientId(clientId);
 				session.setActivateTime(System.currentTimeMillis());

+ 2 - 34
magic-api/src/main/java/org/ssssssss/magicapi/modules/servlet/RequestModule.java

@@ -1,10 +1,10 @@
 package org.ssssssss.magicapi.modules.servlet;
 
-import org.apache.commons.lang3.StringUtils;
 import org.springframework.web.multipart.MultipartFile;
 import org.springframework.web.multipart.MultipartRequest;
 import org.springframework.web.multipart.MultipartResolver;
 import org.ssssssss.magicapi.core.annotation.MagicModule;
+import org.ssssssss.magicapi.utils.IpUtils;
 import org.ssssssss.script.annotation.Comment;
 
 import javax.servlet.http.HttpServletRequest;
@@ -109,38 +109,6 @@ public class RequestModule {
 		if (request == null) {
 			return null;
 		}
-		String ip = null;
-		List<String> headers = Stream.concat(Stream.of(DEFAULT_IP_HEADER), Stream.of(otherHeaderNames)).collect(Collectors.toList());
-		for (String header : headers) {
-			if((ip = processIp(request.getHeader(header))) != null){
-				break;
-			}
-		}
-		return ip == null ? processIp(request.getRemoteAddr()) : ip;
-	}
-
-	private String processIp(String ip) {
-		if (ip != null) {
-			ip = ip.trim();
-			if (isUnknown(ip)) {
-				return null;
-			}
-			if (ip.contains(",")) {
-				String[] ips = ip.split(",");
-				for (String subIp : ips) {
-					ip = processIp(subIp);
-					if (ip != null) {
-						return ip;
-					}
-				}
-			}
-			return ip;
-		}
-		return null;
-	}
-
-	private boolean isUnknown(String ip) {
-		return StringUtils.isBlank(ip) || "unknown".equalsIgnoreCase(ip.trim());
+		return IpUtils.getRealIP(request.getRemoteAddr(), request::getHeader, otherHeaderNames);
 	}
-
 }

+ 48 - 0
magic-api/src/main/java/org/ssssssss/magicapi/utils/IpUtils.java

@@ -0,0 +1,48 @@
+package org.ssssssss.magicapi.utils;
+
+import org.apache.commons.lang3.StringUtils;
+
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class IpUtils {
+
+	private static final String[] DEFAULT_IP_HEADER = new String[]{"X-Forwarded-For", "X-Real-IP", "Proxy-Client-IP", "WL-Proxy-Client-IP", "HTTP_CLIENT_IP", "HTTP_X_FORWARDED_FOR"};
+
+	public static String getRealIP(String remoteAddr, Function<String, String> getHeader, String ... otherHeaderNames){
+		String ip = null;
+		List<String> headers = Stream.concat(Stream.of(DEFAULT_IP_HEADER), Stream.of(otherHeaderNames)).collect(Collectors.toList());
+		for (String header : headers) {
+			if((ip = processIp(getHeader.apply(header))) != null){
+				break;
+			}
+		}
+		return ip == null ? processIp(remoteAddr) : ip;
+	}
+
+	private static String processIp(String ip) {
+		if (ip != null) {
+			ip = ip.trim();
+			if (isUnknown(ip)) {
+				return null;
+			}
+			if (ip.contains(",")) {
+				String[] ips = ip.split(",");
+				for (String subIp : ips) {
+					ip = processIp(subIp);
+					if (ip != null) {
+						return ip;
+					}
+				}
+			}
+			return ip;
+		}
+		return null;
+	}
+
+	private static  boolean isUnknown(String ip) {
+		return StringUtils.isBlank(ip) || "unknown".equalsIgnoreCase(ip.trim());
+	}
+}