Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue295 #302

Merged
merged 13 commits into from
Mar 31, 2024
163 changes: 160 additions & 3 deletions src/main/java/com/ql/util/express/config/QLExpressRunStrategy.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
package com.ql.util.express.config;

import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import com.ql.util.express.config.whitelist.WhiteChecker;
import com.ql.util.express.exception.QLSecurityRiskException;

/**
* ExpressRunner设置全局生效的配置,直接使用静态方法控制
*/
public class QLExpressRunStrategy {
/**
* 沙箱模式开关
*/
private static boolean sandboxMode = false;
/**
* 编译期类型白名单
* null 表示不进行校验
* 如果编译时发现引用了白名单之外的类, 就会抛出异常
*/
private static List<WhiteChecker> compileWhiteCheckerList = null;

/**
* 预防空指针
*/
Expand All @@ -20,25 +34,83 @@ public class QLExpressRunStrategy {
*/
private static boolean compareNullLessMoreAsFalse = false;

private static ClassLoader customClassLoader = null;

/**
* 禁止调用不安全的方法
*/
private static boolean forbidInvokeSecurityRiskMethods = false;
private static boolean forbidInvokeSecurityRiskMethods = true;
private static boolean forbidInvokeSecurityRiskConstructors = true;

private static final List<String> SECURITY_RISK_METHOD_LIST = new ArrayList<>();
/**
* 黑名单控制
*/
private static final Set<String> SECURITY_RISK_METHOD_LIST = new HashSet<>();
private static final Set<String> SECURE_RISK_CONSTRUCTOR_LIST = new HashSet<>();

/**
* 白名单控制
*/
private static Set<String> SECURE_METHOD_LIST = new HashSet<>();
private static Set<String> SECURE_CONSTRUCTOR_LIST = new HashSet<>();

/**
* 最大申请的数组大小, 默认没有限制
* 防止用户一次性申请过多的内存
* -1 表示没有限制
*/
private static int maxArrLength = -1;

static {
// 系统退出
SECURITY_RISK_METHOD_LIST.add(System.class.getName() + "." + "exit");

// 运行脚本命令
SECURITY_RISK_METHOD_LIST.add(Runtime.getRuntime().getClass().getName() + ".exec");
SECURITY_RISK_METHOD_LIST.add(ProcessBuilder.class.getName() + ".start");

// 反射相关
SECURITY_RISK_METHOD_LIST.add(Method.class.getName() + ".invoke");
SECURITY_RISK_METHOD_LIST.add(Class.class.getName() + ".forName");
SECURITY_RISK_METHOD_LIST.add(ClassLoader.class.getName() + ".loadClass");
SECURITY_RISK_METHOD_LIST.add(ClassLoader.class.getName() + ".findClass");
SECURITY_RISK_METHOD_LIST.add(ClassLoader.class.getName() + ".defineClass");
SECURITY_RISK_METHOD_LIST.add(ClassLoader.class.getName() + ".getSystemClassLoader");

// jndi 相关
SECURITY_RISK_METHOD_LIST.add("javax.naming.InitialContext.lookup");
SECURITY_RISK_METHOD_LIST.add("com.sun.rowset.JdbcRowSetImpl.setDataSourceName");
SECURITY_RISK_METHOD_LIST.add("com.sun.rowset.JdbcRowSetImpl.setAutoCommit");

SECURITY_RISK_METHOD_LIST.add("jdk.jshell.JShell.create");
SECURITY_RISK_METHOD_LIST.add("javax.script.ScriptEngineManager.getEngineByName");
SECURITY_RISK_METHOD_LIST.add("org.springframework.jndi.JndiLocatorDelegate.lookup");

// QLE QLExpressRunStrategy的所有方法
for (Method method : QLExpressRunStrategy.class.getMethods()) {
SECURITY_RISK_METHOD_LIST.add(QLExpressRunStrategy.class.getName() + "." + method.getName());
}
addRiskSecureConstructor(java.lang.ProcessBuilder.class);
addRiskSecureConstructor(java.net.Socket.class);
addRiskSecureConstructor(java.io.File.class);
addRiskSecureConstructor(java.awt.Desktop.class);
addRiskSecureConstructor(java.util.PropertyResourceBundle.class);
addRiskSecureConstructor(java.nio.file.Files.class);
addRiskSecureConstructor(java.nio.file.Path.class);
}

private QLExpressRunStrategy() {
throw new IllegalStateException("Utility class");
}

public static void setSandBoxMode(boolean sandboxMode) {
QLExpressRunStrategy.sandboxMode = sandboxMode;
}

public static boolean isSandboxMode() {
return sandboxMode;
}

public static boolean isCompareNullLessMoreAsFalse() {
return compareNullLessMoreAsFalse;
}
Expand All @@ -55,6 +127,14 @@ public static void setAvoidNullPointer(boolean avoidNullPointer) {
QLExpressRunStrategy.avoidNullPointer = avoidNullPointer;
}

public static ClassLoader getCustomClassLoader() {
return customClassLoader;
}

public static void setCustomClassLoader(ClassLoader customClassLoader) {
QLExpressRunStrategy.customClassLoader = customClassLoader;
}

public static boolean isForbidInvokeSecurityRiskMethods() {
return forbidInvokeSecurityRiskMethods;
}
Expand All @@ -63,6 +143,14 @@ public static void setForbidInvokeSecurityRiskMethods(boolean forbidInvokeSecuri
QLExpressRunStrategy.forbidInvokeSecurityRiskMethods = forbidInvokeSecurityRiskMethods;
}

public static boolean isForbidInvokeSecurityRiskConstructors() {
return forbidInvokeSecurityRiskConstructors;
}

public static void setForbidInvokeSecurityRiskConstructors(boolean forbidInvokeSecurityRiskConstructors) {
QLExpressRunStrategy.forbidInvokeSecurityRiskConstructors = forbidInvokeSecurityRiskConstructors;
}

/**
* TODO 未考虑方法重载的场景
*
Expand All @@ -73,14 +161,83 @@ public static void addSecurityRiskMethod(Class<?> clazz, String methodName) {
QLExpressRunStrategy.SECURITY_RISK_METHOD_LIST.add(clazz.getName() + "." + methodName);
}

public static void setSecureMethods(Set<String> secureMethods) {
SECURE_METHOD_LIST = secureMethods;
}

public static void addSecureMethod(Class<?> clazz, String methodName) {
SECURE_METHOD_LIST.add(clazz.getName() + "." + methodName);
}

public static void addRiskSecureConstructor(Class<?> clazz){
SECURE_RISK_CONSTRUCTOR_LIST.add(clazz.getName());
}
public static void addSecureConstructor(Class<?> clazz) {
SECURE_CONSTRUCTOR_LIST.add(clazz.getName());
}

public static void assertSecurityRiskMethod(Method method) throws QLSecurityRiskException {
if (!forbidInvokeSecurityRiskMethods || method == null) {
return;
}

String fullMethodName = method.getDeclaringClass().getName() + "." + method.getName();
if (SECURE_METHOD_LIST != null && !SECURE_METHOD_LIST.isEmpty()) {
// 有白名单配置时则黑名单失效
if (!SECURE_METHOD_LIST.contains(fullMethodName)) {
throw new QLSecurityRiskException("使用QLExpress调用了不安全的系统方法:" + method);
}
return;
}

if (SECURITY_RISK_METHOD_LIST.contains(fullMethodName)) {
throw new QLSecurityRiskException("使用QLExpress调用了不安全的系统方法:" + method);
}
}

public static void assertSecurityRiskConstructor(Constructor constructor) throws QLSecurityRiskException {
if (!forbidInvokeSecurityRiskConstructors || constructor == null) {
return;
}
String fullConstructorName = constructor.getDeclaringClass().getName();
if (SECURE_CONSTRUCTOR_LIST != null && !SECURE_CONSTRUCTOR_LIST.isEmpty()) {
// 有白名单配置时则黑名单失效
if (!SECURE_CONSTRUCTOR_LIST.contains(fullConstructorName)) {
throw new QLSecurityRiskException("使用QLExpress调用了不安全的系统构造函數:" + constructor);
}
return;
}

if (SECURE_RISK_CONSTRUCTOR_LIST.contains(fullConstructorName)) {
throw new QLSecurityRiskException("使用QLExpress调用了不安全的系统构造函數:" + constructor);
}
}

/**
* @param clazz
* @return true 表示位于白名单中, false 表示不在白名单中
*/
public static boolean checkWhiteClassList(Class<?> clazz) {
if (compileWhiteCheckerList == null) {
return true;
}
for (WhiteChecker whiteChecker : compileWhiteCheckerList) {
if (whiteChecker.check(clazz)) {
return true;
}
}
return false;
}

public static void setCompileWhiteCheckerList(List<WhiteChecker> compileWhiteCheckerList) {
QLExpressRunStrategy.compileWhiteCheckerList = compileWhiteCheckerList;
}

public static void setMaxArrLength(int maxArrLength) {
QLExpressRunStrategy.maxArrLength = maxArrLength;
}

public static boolean checkArrLength(int arrLen) {
return QLExpressRunStrategy.maxArrLength == -1 || arrLen <= QLExpressRunStrategy.maxArrLength;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.ql.util.express.ExpressUtil;
import com.ql.util.express.InstructionSetContext;
import com.ql.util.express.OperateData;
import com.ql.util.express.config.QLExpressRunStrategy;
import com.ql.util.express.exception.QLException;
import com.ql.util.express.instruction.OperateDataCacheManager;

Expand Down Expand Up @@ -55,7 +56,7 @@ public OperateData executeInner(InstructionSetContext parent, ArraySwap list) th
s.append(")");
throw new QLException(s.toString());
}

QLExpressRunStrategy.assertSecurityRiskConstructor(c);
tmpObj = c.newInstance(objs);
return OperateDataCacheManager.fetchOperateData(tmpObj, obj);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package com.ql.util.express.bugfix;

import com.ql.util.express.DefaultContext;
import com.ql.util.express.ExpressRunner;
import com.ql.util.express.config.QLExpressRunStrategy;
import com.ql.util.express.example.CustBean;
import com.ql.util.express.exception.QLException;
import com.ql.util.express.exception.QLSecurityRiskException;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

public class InvokeSecurityRiskConstructorsTest {
public InvokeSecurityRiskConstructorsTest(){

}
private boolean preForbidInvokeSecurityRiskConstructors;

@Before
public void before() {
preForbidInvokeSecurityRiskConstructors = QLExpressRunStrategy.isForbidInvokeSecurityRiskConstructors();

//系统默认阻止的方法黑名单:System.exit(1);Runtime.getRuntime().exec()两个函数
QLExpressRunStrategy.setForbidInvokeSecurityRiskConstructors(true);

//白名单
QLExpressRunStrategy.addSecureConstructor(InvokeSecurityRiskConstructorsTest.class);
QLExpressRunStrategy.addSecureConstructor(CustBean.class);
QLExpressRunStrategy.addSecureConstructor(java.util.Date.class);
QLExpressRunStrategy.addSecureConstructor(java.util.LinkedList.class);

//QLExpressRunStrategy.addRiskSecureConstructor(InvokeSecurityRiskConstructorsTest.class);
}

@Test
public void test() throws Exception {
ExpressRunner expressRunner = new ExpressRunner();
DefaultContext<String, Object> context = new DefaultContext<>();

String[] expressList = new String[] {
"import com.ql.util.express.bugfix.InvokeSecurityRiskConstructorsTest;" +
"InvokeSecurityRiskConstructorsTest w = new InvokeSecurityRiskConstructorsTest();return w;"
, "import com.ql.util.express.bugfix.InvokeSecurityRiskMethodsTest;" +
"InvokeSecurityRiskMethodsTest w = new InvokeSecurityRiskMethodsTest();"};

Object result = expressRunner.execute(expressList[0], context, null, true, false, 1000);
Assert.assertTrue(result instanceof InvokeSecurityRiskConstructorsTest);

try {
result = expressRunner.execute(expressList[1], context, null, true, false, 1000);
Assert.fail();
}catch (QLException e) {
//预期内走这里
Assert.assertEquals(e.getCause().getMessage(), "使用QLExpress调用了不安全的系统构造函數:public com.ql.util.express.bugfix.InvokeSecurityRiskMethodsTest()");
}
}

@Test
public void testDefault() throws Exception {
ExpressRunner expressRunner = new ExpressRunner();
DefaultContext<String, Object> context = new DefaultContext<>();
String[] expressList = new String[] {
"import java.net.Socket;" +
"return new Socket();"};

try {
Object result = expressRunner.execute(expressList[0], context, null, false, false, 1000);
Assert.fail();
}catch (QLException e) {
//预期内走这里
Assert.assertEquals(e.getCause().getMessage(), "使用QLExpress调用了不安全的系统构造函數:public java.net.Socket()");
}
QLExpressRunStrategy.addSecureConstructor(java.net.Socket.class);

Object result = expressRunner.execute(expressList[0], context, null, true, false, 1000);
Assert.assertTrue(result instanceof java.net.Socket);
}
}