SimpleHttpResponse adds deserialization whitelist

pull/2608/head
hengyunabc 2 years ago
parent af70d95383
commit 76fef20ff6

@ -12,8 +12,16 @@
<url>https://github.com/alibaba/arthas</url> <url>https://github.com/alibaba/arthas</url>
<dependencies> <dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
</project> </project>

@ -3,11 +3,14 @@ package com.alibaba.arthas.tunnel.common;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.ObjectInput; import java.io.InvalidClassException;
import java.io.ObjectInputStream; import java.io.ObjectInputStream;
import java.io.ObjectOutputStream; import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.Serializable; import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
/** /**
@ -16,9 +19,11 @@ import java.util.Map;
* *
*/ */
public class SimpleHttpResponse implements Serializable { public class SimpleHttpResponse implements Serializable {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
private static final List<String> whitelist = Arrays.asList(byte[].class.getName(), String.class.getName(),
Map.class.getName(), HashMap.class.getName(), SimpleHttpResponse.class.getName());
private int status = 200; private int status = 200;
private Map<String, String> headers = new HashMap<String, String>(); private Map<String, String> headers = new HashMap<String, String>();
@ -55,35 +60,25 @@ public class SimpleHttpResponse implements Serializable {
public static byte[] toBytes(SimpleHttpResponse response) throws IOException { public static byte[] toBytes(SimpleHttpResponse response) throws IOException {
ByteArrayOutputStream bos = new ByteArrayOutputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream();
ObjectOutputStream out = null; try (ObjectOutputStream out = new ObjectOutputStream(bos)) {
try {
out = new ObjectOutputStream(bos);
out.writeObject(response); out.writeObject(response);
out.flush(); out.flush();
return bos.toByteArray(); return bos.toByteArray();
} finally {
try {
bos.close();
} catch (IOException ex) {
// ignore close exception
}
} }
} }
public static SimpleHttpResponse fromBytes(byte[] bytes) throws IOException, ClassNotFoundException { public static SimpleHttpResponse fromBytes(byte[] bytes) throws IOException, ClassNotFoundException {
ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
ObjectInput in = null; try (ObjectInputStream in = new ObjectInputStream(bis) {
try { protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
in = new ObjectInputStream(bis); if (!whitelist.contains(desc.getName())) {
return (SimpleHttpResponse) in.readObject(); throw new InvalidClassException("Unauthorized deserialization attempt", desc.getName());
} finally {
try {
if (in != null) {
in.close();
} }
} catch (IOException ex) { return super.resolveClass(desc);
// ignore close exception
} }
}) {
return (SimpleHttpResponse) in.readObject();
} }
} }
} }

@ -0,0 +1,59 @@
package com.alibaba.arthas.tunnel.common;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InvalidClassException;
import java.io.ObjectOutputStream;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
public class SimpleHttpResponseTest {
@Test
public void testSerialization() throws IOException, ClassNotFoundException {
SimpleHttpResponse response = new SimpleHttpResponse();
response.setStatus(200);
Map<String, String> headers = new HashMap<String, String>();
headers.put("Content-Type", "text/plain");
response.setHeaders(headers);
String content = "Hello, world!";
response.setContent(content.getBytes());
byte[] bytes = SimpleHttpResponse.toBytes(response);
SimpleHttpResponse deserializedResponse = SimpleHttpResponse.fromBytes(bytes);
assertEquals(response.getStatus(), deserializedResponse.getStatus());
assertEquals(response.getHeaders(), deserializedResponse.getHeaders());
assertArrayEquals(response.getContent(), deserializedResponse.getContent());
}
private static byte[] toBytes(Object object) throws IOException {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
try (ObjectOutputStream out = new ObjectOutputStream(bos)) {
out.writeObject(object);
out.flush();
return bos.toByteArray();
}
}
@Test(expected = InvalidClassException.class)
public void testDeserializationWithUnauthorizedClass() throws IOException, ClassNotFoundException {
Date date = new Date();
byte[] bytes = toBytes(date);
// Try to deserialize the object with an unauthorized class
// This should throw an InvalidClassException
SimpleHttpResponse.fromBytes(bytes);
}
}
Loading…
Cancel
Save