diff --git a/tunnel-common/pom.xml b/tunnel-common/pom.xml
index 8a80f0270..b99802833 100644
--- a/tunnel-common/pom.xml
+++ b/tunnel-common/pom.xml
@@ -12,8 +12,16 @@
https://github.com/alibaba/arthas
-
-
+
+ junit
+ junit
+ test
+
+
+ org.assertj
+ assertj-core
+ test
+
diff --git a/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java b/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java
index bdaacc20c..302216b33 100644
--- a/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java
+++ b/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java
@@ -3,11 +3,14 @@ package com.alibaba.arthas.tunnel.common;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
-import java.io.ObjectInput;
+import java.io.InvalidClassException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
+import java.io.ObjectStreamClass;
import java.io.Serializable;
+import java.util.Arrays;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
/**
@@ -16,9 +19,11 @@ import java.util.Map;
*
*/
public class SimpleHttpResponse implements Serializable {
-
private static final long serialVersionUID = 1L;
+ private static final List whitelist = Arrays.asList(byte[].class.getName(), String.class.getName(),
+ Map.class.getName(), HashMap.class.getName(), SimpleHttpResponse.class.getName());
+
private int status = 200;
private Map headers = new HashMap();
@@ -55,35 +60,25 @@ public class SimpleHttpResponse implements Serializable {
public static byte[] toBytes(SimpleHttpResponse response) throws IOException {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
- ObjectOutputStream out = null;
- try {
- out = new ObjectOutputStream(bos);
+ try (ObjectOutputStream out = new ObjectOutputStream(bos)) {
out.writeObject(response);
out.flush();
return bos.toByteArray();
- } finally {
- try {
- bos.close();
- } catch (IOException ex) {
- // ignore close exception
- }
}
}
public static SimpleHttpResponse fromBytes(byte[] bytes) throws IOException, ClassNotFoundException {
ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
- ObjectInput in = null;
- try {
- in = new ObjectInputStream(bis);
- return (SimpleHttpResponse) in.readObject();
- } finally {
- try {
- if (in != null) {
- in.close();
+ try (ObjectInputStream in = new ObjectInputStream(bis) {
+ protected Class> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
+ if (!whitelist.contains(desc.getName())) {
+ throw new InvalidClassException("Unauthorized deserialization attempt", desc.getName());
}
- } catch (IOException ex) {
- // ignore close exception
+ return super.resolveClass(desc);
}
+ }) {
+ return (SimpleHttpResponse) in.readObject();
}
}
+
}
diff --git a/tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java b/tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java
new file mode 100644
index 000000000..477b7203a
--- /dev/null
+++ b/tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java
@@ -0,0 +1,59 @@
+package com.alibaba.arthas.tunnel.common;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InvalidClassException;
+import java.io.ObjectOutputStream;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.Test;
+
+public class SimpleHttpResponseTest {
+
+ @Test
+ public void testSerialization() throws IOException, ClassNotFoundException {
+ SimpleHttpResponse response = new SimpleHttpResponse();
+ response.setStatus(200);
+
+ Map headers = new HashMap();
+ headers.put("Content-Type", "text/plain");
+ response.setHeaders(headers);
+
+ String content = "Hello, world!";
+ response.setContent(content.getBytes());
+
+ byte[] bytes = SimpleHttpResponse.toBytes(response);
+
+ SimpleHttpResponse deserializedResponse = SimpleHttpResponse.fromBytes(bytes);
+
+ assertEquals(response.getStatus(), deserializedResponse.getStatus());
+ assertEquals(response.getHeaders(), deserializedResponse.getHeaders());
+ assertArrayEquals(response.getContent(), deserializedResponse.getContent());
+ }
+
+ private static byte[] toBytes(Object object) throws IOException {
+ ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ try (ObjectOutputStream out = new ObjectOutputStream(bos)) {
+ out.writeObject(object);
+ out.flush();
+ return bos.toByteArray();
+ }
+ }
+
+ @Test(expected = InvalidClassException.class)
+ public void testDeserializationWithUnauthorizedClass() throws IOException, ClassNotFoundException {
+ Date date = new Date();
+
+ byte[] bytes = toBytes(date);
+
+ // Try to deserialize the object with an unauthorized class
+ // This should throw an InvalidClassException
+ SimpleHttpResponse.fromBytes(bytes);
+ }
+
+}