SimpleHttpResponse adds deserialization whitelist

pull/2633/head
hengyunabc 2 years ago
parent be646f72a1
commit fa8e539a73

@ -12,8 +12,16 @@
<url>https://github.com/alibaba/arthas</url> <url>https://github.com/alibaba/arthas</url>
<dependencies> <dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<build> <build>

@ -3,11 +3,14 @@ package com.alibaba.arthas.tunnel.common;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.ObjectInput; import java.io.InvalidClassException;
import java.io.ObjectInputStream; import java.io.ObjectInputStream;
import java.io.ObjectOutputStream; import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.Serializable; import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
/** /**
@ -16,9 +19,11 @@ import java.util.Map;
* *
*/ */
public class SimpleHttpResponse implements Serializable { public class SimpleHttpResponse implements Serializable {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
private static final List<String> whitelist = Arrays.asList(byte[].class.getName(), String.class.getName(),
Map.class.getName(), HashMap.class.getName(), SimpleHttpResponse.class.getName());
private int status = 200; private int status = 200;
private Map<String, String> headers = new HashMap<String, String>(); private Map<String, String> headers = new HashMap<String, String>();
@ -62,27 +67,28 @@ public class SimpleHttpResponse implements Serializable {
out.flush(); out.flush();
return bos.toByteArray(); return bos.toByteArray();
} finally { } finally {
try { if (out != null) {
bos.close(); out.close();
} catch (IOException ex) {
// ignore close exception
} }
} }
} }
public static SimpleHttpResponse fromBytes(byte[] bytes) throws IOException, ClassNotFoundException { public static SimpleHttpResponse fromBytes(byte[] bytes) throws IOException, ClassNotFoundException {
ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
ObjectInput in = null; ObjectInputStream in = null;
try { try {
in = new ObjectInputStream(bis); in = new ObjectInputStream(bis) {
protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
if (!whitelist.contains(desc.getName())) {
throw new InvalidClassException("Unauthorized deserialization attempt", desc.getName());
}
return super.resolveClass(desc);
}
};
return (SimpleHttpResponse) in.readObject(); return (SimpleHttpResponse) in.readObject();
} finally { } finally {
try { if (in != null) {
if (in != null) { in.close();
in.close();
}
} catch (IOException ex) {
// ignore close exception
} }
} }
} }

@ -0,0 +1,65 @@
package com.alibaba.arthas.tunnel.common;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InvalidClassException;
import java.io.ObjectOutputStream;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
public class SimpleHttpResponseTest {
@Test
public void testSerialization() throws IOException, ClassNotFoundException {
SimpleHttpResponse response = new SimpleHttpResponse();
response.setStatus(200);
Map<String, String> headers = new HashMap<String, String>();
headers.put("Content-Type", "text/plain");
response.setHeaders(headers);
String content = "Hello, world!";
response.setContent(content.getBytes());
byte[] bytes = SimpleHttpResponse.toBytes(response);
SimpleHttpResponse deserializedResponse = SimpleHttpResponse.fromBytes(bytes);
assertEquals(response.getStatus(), deserializedResponse.getStatus());
assertEquals(response.getHeaders(), deserializedResponse.getHeaders());
assertArrayEquals(response.getContent(), deserializedResponse.getContent());
}
public static byte[] toBytes(Object response) throws IOException {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
ObjectOutputStream out = null;
try {
out = new ObjectOutputStream(bos);
out.writeObject(response);
out.flush();
return bos.toByteArray();
} finally {
if (out != null) {
out.close();
}
}
}
@Test(expected = InvalidClassException.class)
public void testDeserializationWithUnauthorizedClass() throws IOException, ClassNotFoundException {
Date date = new Date();
byte[] bytes = toBytes(date);
// Try to deserialize the object with an unauthorized class
// This should throw an InvalidClassException
SimpleHttpResponse.fromBytes(bytes);
}
}
Loading…
Cancel
Save