SquashSRS4: Refine DTLS init, use specified API by role

pull/2252/head
winlin 4 years ago
parent de65a331f1
commit dc93836489

@ -47,8 +47,11 @@ public:
virtual ~ISrsDisposingHandler(); virtual ~ISrsDisposingHandler();
public: public:
// When before disposing resource, trigger when manager.remove(c), sync API. // When before disposing resource, trigger when manager.remove(c), sync API.
// @remark Recommend to unref c, after this, no other objects refs to c.
virtual void on_before_dispose(ISrsResource* c) = 0; virtual void on_before_dispose(ISrsResource* c) = 0;
// When disposing resource, async API, c is freed after it. // When disposing resource, async API, c is freed after it.
// @remark Recommend to stop any thread/timer of c, after this, fields of c is able
// to be deleted in any order.
virtual void on_disposing(ISrsResource* c) = 0; virtual void on_disposing(ISrsResource* c) = 0;
}; };

@ -96,16 +96,24 @@ void ssl_on_info(const SSL* dtls, int where, int ret)
} }
} }
SSL_CTX* srs_build_dtls_ctx(SrsDtlsVersion version) SSL_CTX* srs_build_dtls_ctx(SrsDtlsVersion version, std::string role)
{ {
SSL_CTX* dtls_ctx; SSL_CTX* dtls_ctx;
#if OPENSSL_VERSION_NUMBER < 0x10002000L // v1.0.2 #if OPENSSL_VERSION_NUMBER < 0x10002000L // v1.0.2
dtls_ctx = SSL_CTX_new(DTLSv1_method()); dtls_ctx = SSL_CTX_new(DTLSv1_method());
#else #else
if (version == SrsDtlsVersion1_0) { if (version == SrsDtlsVersion1_0) {
dtls_ctx = SSL_CTX_new(DTLSv1_method()); if (role == "active") {
dtls_ctx = SSL_CTX_new(DTLSv1_client_method());
} else {
dtls_ctx = SSL_CTX_new(DTLSv1_server_method());
}
} else if (version == SrsDtlsVersion1_2) { } else if (version == SrsDtlsVersion1_2) {
dtls_ctx = SSL_CTX_new(DTLSv1_2_method()); if (role == "active") {
dtls_ctx = SSL_CTX_new(DTLS_client_method());
} else {
dtls_ctx = SSL_CTX_new(DTLS_server_method());
}
} else { } else {
// SrsDtlsVersionAuto, use version-flexible DTLS methods // SrsDtlsVersionAuto, use version-flexible DTLS methods
dtls_ctx = SSL_CTX_new(DTLS_method()); dtls_ctx = SSL_CTX_new(DTLS_method());
@ -397,7 +405,7 @@ SrsDtlsImpl::~SrsDtlsImpl()
srs_freepa(last_outgoing_packet_cache); srs_freepa(last_outgoing_packet_cache);
} }
srs_error_t SrsDtlsImpl::initialize(std::string version) srs_error_t SrsDtlsImpl::initialize(std::string version, std::string role)
{ {
srs_error_t err = srs_success; srs_error_t err = srs_success;
@ -409,7 +417,7 @@ srs_error_t SrsDtlsImpl::initialize(std::string version)
version_ = SrsDtlsVersionAuto; version_ = SrsDtlsVersionAuto;
} }
dtls_ctx = srs_build_dtls_ctx(version_); dtls_ctx = srs_build_dtls_ctx(version_, role);
if ((dtls = SSL_new(dtls_ctx)) == NULL) { if ((dtls = SSL_new(dtls_ctx)) == NULL) {
return srs_error_new(ERROR_OpenSslCreateSSL, "SSL_new dtls"); return srs_error_new(ERROR_OpenSslCreateSSL, "SSL_new dtls");
@ -418,6 +426,11 @@ srs_error_t SrsDtlsImpl::initialize(std::string version)
SSL_set_ex_data(dtls, 0, this); SSL_set_ex_data(dtls, 0, this);
SSL_set_info_callback(dtls, ssl_on_info); SSL_set_info_callback(dtls, ssl_on_info);
// set dtls fragment
// @see https://stackoverflow.com/questions/62413602/openssl-server-packets-get-fragmented-into-270-bytes-per-packet
SSL_set_options(dtls, SSL_OP_NO_QUERY_MTU);
SSL_set_mtu(dtls, kRtpPacketSize);
if ((bio_in = BIO_new(BIO_s_mem())) == NULL) { if ((bio_in = BIO_new(BIO_s_mem())) == NULL) {
return srs_error_new(ERROR_OpenSslBIONew, "BIO_new in"); return srs_error_new(ERROR_OpenSslBIONew, "BIO_new in");
} }
@ -643,11 +656,11 @@ SrsDtlsClientImpl::~SrsDtlsClientImpl()
srs_freep(trd); srs_freep(trd);
} }
srs_error_t SrsDtlsClientImpl::initialize(std::string version) srs_error_t SrsDtlsClientImpl::initialize(std::string version, std::string role)
{ {
srs_error_t err = srs_success; srs_error_t err = srs_success;
if ((err = SrsDtlsImpl::initialize(version)) != srs_success) { if ((err = SrsDtlsImpl::initialize(version, role)) != srs_success) {
return err; return err;
} }
@ -819,11 +832,11 @@ SrsDtlsServerImpl::~SrsDtlsServerImpl()
{ {
} }
srs_error_t SrsDtlsServerImpl::initialize(std::string version) srs_error_t SrsDtlsServerImpl::initialize(std::string version, std::string role)
{ {
srs_error_t err = srs_success; srs_error_t err = srs_success;
if ((err = SrsDtlsImpl::initialize(version)) != srs_success) { if ((err = SrsDtlsImpl::initialize(version, role)) != srs_success) {
return err; return err;
} }
@ -892,7 +905,7 @@ srs_error_t SrsDtls::initialize(std::string role, std::string version)
impl = new SrsDtlsServerImpl(callback_); impl = new SrsDtlsServerImpl(callback_);
} }
return impl->initialize(version); return impl->initialize(version, role);
} }
srs_error_t SrsDtls::start_active_handshake() srs_error_t SrsDtls::start_active_handshake()

@ -130,7 +130,7 @@ public:
SrsDtlsImpl(ISrsDtlsCallback* callback); SrsDtlsImpl(ISrsDtlsCallback* callback);
virtual ~SrsDtlsImpl(); virtual ~SrsDtlsImpl();
public: public:
virtual srs_error_t initialize(std::string version); virtual srs_error_t initialize(std::string version, std::string role);
virtual srs_error_t start_active_handshake() = 0; virtual srs_error_t start_active_handshake() = 0;
virtual srs_error_t on_dtls(char* data, int nb_data); virtual srs_error_t on_dtls(char* data, int nb_data);
protected: protected:
@ -162,7 +162,7 @@ public:
SrsDtlsClientImpl(ISrsDtlsCallback* callback); SrsDtlsClientImpl(ISrsDtlsCallback* callback);
virtual ~SrsDtlsClientImpl(); virtual ~SrsDtlsClientImpl();
public: public:
virtual srs_error_t initialize(std::string version); virtual srs_error_t initialize(std::string version, std::string role);
virtual srs_error_t start_active_handshake(); virtual srs_error_t start_active_handshake();
virtual srs_error_t on_dtls(char* data, int nb_data); virtual srs_error_t on_dtls(char* data, int nb_data);
protected: protected:
@ -183,7 +183,7 @@ public:
SrsDtlsServerImpl(ISrsDtlsCallback* callback); SrsDtlsServerImpl(ISrsDtlsCallback* callback);
virtual ~SrsDtlsServerImpl(); virtual ~SrsDtlsServerImpl();
public: public:
virtual srs_error_t initialize(std::string version); virtual srs_error_t initialize(std::string version, std::string role);
virtual srs_error_t start_active_handshake(); virtual srs_error_t start_active_handshake();
protected: protected:
virtual void on_ssl_out_data(uint8_t*& data, int& size, bool& cached); virtual void on_ssl_out_data(uint8_t*& data, int& size, bool& cached);

@ -163,6 +163,8 @@ SrsFastCoroutine::SrsFastCoroutine(string n, ISrsCoroutineHandler* h, SrsContext
SrsFastCoroutine::~SrsFastCoroutine() SrsFastCoroutine::~SrsFastCoroutine()
{ {
stop(); stop();
// TODO: FIXME: We must assert the cycle is done.
srs_freep(trd_err); srs_freep(trd_err);
} }
@ -213,7 +215,7 @@ void SrsFastCoroutine::stop()
interrupt(); interrupt();
// When not started, the rd is NULL. // When not started, the trd is NULL.
if (trd) { if (trd) {
void* res = NULL; void* res = NULL;
int r0 = st_thread_join((st_thread_t)trd, &res); int r0 = st_thread_join((st_thread_t)trd, &res);
@ -245,7 +247,9 @@ void SrsFastCoroutine::interrupt()
if (trd_err == srs_success) { if (trd_err == srs_success) {
trd_err = srs_error_new(ERROR_THREAD_INTERRUPED, "interrupted"); trd_err = srs_error_new(ERROR_THREAD_INTERRUPED, "interrupted");
} }
// Note that if another thread is stopping thread and waiting in st_thread_join,
// the interrupt will make the st_thread_join fail.
st_thread_interrupt((st_thread_t)trd); st_thread_interrupt((st_thread_t)trd);
} }

@ -563,7 +563,7 @@ VOID TEST(KernelRTCTest, StringDumpHexTest)
} }
} }
extern SSL_CTX* srs_build_dtls_ctx(SrsDtlsVersion version); extern SSL_CTX* srs_build_dtls_ctx(SrsDtlsVersion version, std::string role);
class MockDtls class MockDtls
{ {
@ -625,7 +625,7 @@ srs_error_t MockDtls::initialize(std::string role, std::string version)
version_ = SrsDtlsVersionAuto; version_ = SrsDtlsVersionAuto;
} }
dtls_ctx = srs_build_dtls_ctx(version_); dtls_ctx = srs_build_dtls_ctx(version_, role);
dtls = SSL_new(dtls_ctx); dtls = SSL_new(dtls_ctx);
srs_assert(dtls); srs_assert(dtls);

Loading…
Cancel
Save