From fa8e539a73c48cfe04919ef1dc06f41a5d7886ea Mon Sep 17 00:00:00 2001 From: hengyunabc Date: Wed, 2 Aug 2023 16:00:58 +0800 Subject: [PATCH] SimpleHttpResponse adds deserialization whitelist --- tunnel-common/pom.xml | 12 +++- .../tunnel/common/SimpleHttpResponse.java | 34 ++++++---- .../tunnel/common/SimpleHttpResponseTest.java | 65 +++++++++++++++++++ 3 files changed, 95 insertions(+), 16 deletions(-) create mode 100644 tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java diff --git a/tunnel-common/pom.xml b/tunnel-common/pom.xml index c5a9a2204..440fc4a3a 100644 --- a/tunnel-common/pom.xml +++ b/tunnel-common/pom.xml @@ -12,8 +12,16 @@ https://github.com/alibaba/arthas - - + + junit + junit + test + + + org.assertj + assertj-core + test + diff --git a/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java b/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java index bdaacc20c..8c49c800d 100644 --- a/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java +++ b/tunnel-common/src/main/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponse.java @@ -3,11 +3,14 @@ package com.alibaba.arthas.tunnel.common; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.ObjectInput; +import java.io.InvalidClassException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; import java.io.Serializable; +import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; /** @@ -16,9 +19,11 @@ import java.util.Map; * */ public class SimpleHttpResponse implements Serializable { - private static final long serialVersionUID = 1L; + private static final List whitelist = Arrays.asList(byte[].class.getName(), String.class.getName(), + Map.class.getName(), HashMap.class.getName(), SimpleHttpResponse.class.getName()); + private int status = 200; private Map headers = new HashMap(); @@ -62,27 +67,28 @@ public class SimpleHttpResponse implements Serializable { out.flush(); return bos.toByteArray(); } finally { - try { - bos.close(); - } catch (IOException ex) { - // ignore close exception + if (out != null) { + out.close(); } } } public static SimpleHttpResponse fromBytes(byte[] bytes) throws IOException, ClassNotFoundException { ByteArrayInputStream bis = new ByteArrayInputStream(bytes); - ObjectInput in = null; + ObjectInputStream in = null; try { - in = new ObjectInputStream(bis); + in = new ObjectInputStream(bis) { + protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { + if (!whitelist.contains(desc.getName())) { + throw new InvalidClassException("Unauthorized deserialization attempt", desc.getName()); + } + return super.resolveClass(desc); + } + }; return (SimpleHttpResponse) in.readObject(); } finally { - try { - if (in != null) { - in.close(); - } - } catch (IOException ex) { - // ignore close exception + if (in != null) { + in.close(); } } } diff --git a/tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java b/tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java new file mode 100644 index 000000000..872716305 --- /dev/null +++ b/tunnel-common/src/test/java/com/alibaba/arthas/tunnel/common/SimpleHttpResponseTest.java @@ -0,0 +1,65 @@ +package com.alibaba.arthas.tunnel.common; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InvalidClassException; +import java.io.ObjectOutputStream; +import java.util.Date; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; + +public class SimpleHttpResponseTest { + + @Test + public void testSerialization() throws IOException, ClassNotFoundException { + SimpleHttpResponse response = new SimpleHttpResponse(); + response.setStatus(200); + + Map headers = new HashMap(); + headers.put("Content-Type", "text/plain"); + response.setHeaders(headers); + + String content = "Hello, world!"; + response.setContent(content.getBytes()); + + byte[] bytes = SimpleHttpResponse.toBytes(response); + + SimpleHttpResponse deserializedResponse = SimpleHttpResponse.fromBytes(bytes); + + assertEquals(response.getStatus(), deserializedResponse.getStatus()); + assertEquals(response.getHeaders(), deserializedResponse.getHeaders()); + assertArrayEquals(response.getContent(), deserializedResponse.getContent()); + } + + public static byte[] toBytes(Object response) throws IOException { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream out = null; + try { + out = new ObjectOutputStream(bos); + out.writeObject(response); + out.flush(); + return bos.toByteArray(); + } finally { + if (out != null) { + out.close(); + } + } + } + + @Test(expected = InvalidClassException.class) + public void testDeserializationWithUnauthorizedClass() throws IOException, ClassNotFoundException { + Date date = new Date(); + + byte[] bytes = toBytes(date); + + // Try to deserialize the object with an unauthorized class + // This should throw an InvalidClassException + SimpleHttpResponse.fromBytes(bytes); + } + +}