diff --git a/tunnel-common/pom.xml b/tunnel-common/pom.xml
index c5a9a2204..440fc4a3a 100644
--- a/tunnel-common/pom.xml
+++ b/tunnel-common/pom.xml
@@ -12,8 +12,16 @@
https://github.com/alibaba/arthas
-
-
+
+ junit
+ junit
+ test
+
+
+ org.assertj
+ assertj-core
+ test
+
diff --git a/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java b/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java
index bdaacc20c..8c49c800d 100644
--- a/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java
+++ b/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java
@@ -3,11 +3,14 @@ package com.alibaba.arthas.tunnel.common;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
-import java.io.ObjectInput;
+import java.io.InvalidClassException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
+import java.io.ObjectStreamClass;
import java.io.Serializable;
+import java.util.Arrays;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
/**
@@ -16,9 +19,11 @@ import java.util.Map;
*
*/
public class SimpleHttpResponse implements Serializable {
-
private static final long serialVersionUID = 1L;
+ private static final List whitelist = Arrays.asList(byte[].class.getName(), String.class.getName(),
+ Map.class.getName(), HashMap.class.getName(), SimpleHttpResponse.class.getName());
+
private int status = 200;
private Map headers = new HashMap();
@@ -62,27 +67,28 @@ public class SimpleHttpResponse implements Serializable {
out.flush();
return bos.toByteArray();
} finally {
- try {
- bos.close();
- } catch (IOException ex) {
- // ignore close exception
+ if (out != null) {
+ out.close();
}
}
}
public static SimpleHttpResponse fromBytes(byte[] bytes) throws IOException, ClassNotFoundException {
ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
- ObjectInput in = null;
+ ObjectInputStream in = null;
try {
- in = new ObjectInputStream(bis);
+ in = new ObjectInputStream(bis) {
+ protected Class> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
+ if (!whitelist.contains(desc.getName())) {
+ throw new InvalidClassException("Unauthorized deserialization attempt", desc.getName());
+ }
+ return super.resolveClass(desc);
+ }
+ };
return (SimpleHttpResponse) in.readObject();
} finally {
- try {
- if (in != null) {
- in.close();
- }
- } catch (IOException ex) {
- // ignore close exception
+ if (in != null) {
+ in.close();
}
}
}
diff --git a/tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java b/tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java
new file mode 100644
index 000000000..872716305
--- /dev/null
+++ b/tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java
@@ -0,0 +1,65 @@
+package com.alibaba.arthas.tunnel.common;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InvalidClassException;
+import java.io.ObjectOutputStream;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.Test;
+
+public class SimpleHttpResponseTest {
+
+ @Test
+ public void testSerialization() throws IOException, ClassNotFoundException {
+ SimpleHttpResponse response = new SimpleHttpResponse();
+ response.setStatus(200);
+
+ Map headers = new HashMap();
+ headers.put("Content-Type", "text/plain");
+ response.setHeaders(headers);
+
+ String content = "Hello, world!";
+ response.setContent(content.getBytes());
+
+ byte[] bytes = SimpleHttpResponse.toBytes(response);
+
+ SimpleHttpResponse deserializedResponse = SimpleHttpResponse.fromBytes(bytes);
+
+ assertEquals(response.getStatus(), deserializedResponse.getStatus());
+ assertEquals(response.getHeaders(), deserializedResponse.getHeaders());
+ assertArrayEquals(response.getContent(), deserializedResponse.getContent());
+ }
+
+ public static byte[] toBytes(Object response) throws IOException {
+ ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ ObjectOutputStream out = null;
+ try {
+ out = new ObjectOutputStream(bos);
+ out.writeObject(response);
+ out.flush();
+ return bos.toByteArray();
+ } finally {
+ if (out != null) {
+ out.close();
+ }
+ }
+ }
+
+ @Test(expected = InvalidClassException.class)
+ public void testDeserializationWithUnauthorizedClass() throws IOException, ClassNotFoundException {
+ Date date = new Date();
+
+ byte[] bytes = toBytes(date);
+
+ // Try to deserialize the object with an unauthorized class
+ // This should throw an InvalidClassException
+ SimpleHttpResponse.fromBytes(bytes);
+ }
+
+}