Feature - allowedClasses setting added to SerializationCodec https://github.com/redisson/redisson/security/code-scanning/4

pull/5079/head
Nikita Koksharov 2 years ago
parent c70943b27e
commit fe6a257180

@ -15,13 +15,11 @@
*/ */
package org.redisson.codec; package org.redisson.codec;
import java.io.IOException; import java.io.*;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import java.lang.reflect.Proxy; import java.lang.reflect.Proxy;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Set;
/** /**
* *
@ -31,6 +29,13 @@ import java.util.List;
public class CustomObjectInputStream extends ObjectInputStream { public class CustomObjectInputStream extends ObjectInputStream {
private final ClassLoader classLoader; private final ClassLoader classLoader;
private Set<String> allowedClasses;
public CustomObjectInputStream(ClassLoader classLoader, InputStream in,Set<String> allowedClasses) throws IOException {
super(in);
this.classLoader = classLoader;
this.allowedClasses = allowedClasses;
}
public CustomObjectInputStream(ClassLoader classLoader, InputStream in) throws IOException { public CustomObjectInputStream(ClassLoader classLoader, InputStream in) throws IOException {
super(in); super(in);
@ -41,6 +46,9 @@ public class CustomObjectInputStream extends ObjectInputStream {
protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
try { try {
String name = desc.getName(); String name = desc.getName();
if (allowedClasses != null && !allowedClasses.contains(name)) {
throw new InvalidClassException("Class " + name + " isn't allowed");
}
return Class.forName(name, false, classLoader); return Class.forName(name, false, classLoader);
} catch (ClassNotFoundException e) { } catch (ClassNotFoundException e) {
return super.resolveClass(desc); return super.resolveClass(desc);
@ -56,7 +64,7 @@ public class CustomObjectInputStream extends ObjectInputStream {
loadedClasses.add(clazz); loadedClasses.add(clazz);
} }
return Proxy.getProxyClass(classLoader, loadedClasses.toArray(new Class[loadedClasses.size()])); return Proxy.getProxyClass(classLoader, loadedClasses.toArray(new Class[0]));
} }
} }

@ -15,19 +15,19 @@
*/ */
package org.redisson.codec; package org.redisson.codec;
import java.io.IOException; import io.netty.buffer.ByteBuf;
import java.io.ObjectInputStream; import io.netty.buffer.ByteBufAllocator;
import java.io.ObjectOutputStream; import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.ByteBufOutputStream;
import org.redisson.client.codec.BaseCodec; import org.redisson.client.codec.BaseCodec;
import org.redisson.client.handler.State; import org.redisson.client.handler.State;
import org.redisson.client.protocol.Decoder; import org.redisson.client.protocol.Decoder;
import org.redisson.client.protocol.Encoder; import org.redisson.client.protocol.Encoder;
import io.netty.buffer.ByteBuf; import java.io.IOException;
import io.netty.buffer.ByteBufAllocator; import java.io.ObjectInputStream;
import io.netty.buffer.ByteBufInputStream; import java.io.ObjectOutputStream;
import io.netty.buffer.ByteBufOutputStream; import java.util.Set;
/** /**
* JDK's serialization codec. * JDK's serialization codec.
@ -51,7 +51,7 @@ public class SerializationCodec extends BaseCodec {
ObjectInputStream inputStream; ObjectInputStream inputStream;
if (classLoader != null) { if (classLoader != null) {
Thread.currentThread().setContextClassLoader(classLoader); Thread.currentThread().setContextClassLoader(classLoader);
inputStream = new CustomObjectInputStream(classLoader, in); inputStream = new CustomObjectInputStream(classLoader, in, allowedClasses);
} else { } else {
inputStream = new ObjectInputStream(in); inputStream = new ObjectInputStream(in);
} }
@ -85,6 +85,7 @@ public class SerializationCodec extends BaseCodec {
} }
}; };
private Set<String> allowedClasses;
private final ClassLoader classLoader; private final ClassLoader classLoader;
public SerializationCodec() { public SerializationCodec() {
@ -97,6 +98,12 @@ public class SerializationCodec extends BaseCodec {
public SerializationCodec(ClassLoader classLoader, SerializationCodec codec) { public SerializationCodec(ClassLoader classLoader, SerializationCodec codec) {
this.classLoader = classLoader; this.classLoader = classLoader;
this.allowedClasses = codec.allowedClasses;
}
public SerializationCodec(ClassLoader classLoader, Set<String> allowedClasses) {
this.classLoader = classLoader;
this.allowedClasses = allowedClasses;
} }
@Override @Override

Loading…
Cancel
Save