From 99016af42a1b2bec80dd1eecbd5338da58ad2ac5 Mon Sep 17 00:00:00 2001 From: "jinxue.cgh" Date: Thu, 25 Jun 2020 20:47:17 +0800 Subject: [PATCH] RTC: transport use single srtp --- trunk/src/app/srs_app_rtc_conn.cpp | 74 ++++++++++++-------------- trunk/src/app/srs_app_rtc_conn.hpp | 20 ++++--- trunk/src/app/srs_app_rtc_dtls.cpp | 85 ++++++++++++++++-------------- trunk/src/app/srs_app_rtc_dtls.hpp | 22 +++++--- 4 files changed, 108 insertions(+), 93 deletions(-) diff --git a/trunk/src/app/srs_app_rtc_conn.cpp b/trunk/src/app/srs_app_rtc_conn.cpp index 5b2a6b57f..66bf8345b 100644 --- a/trunk/src/app/srs_app_rtc_conn.cpp +++ b/trunk/src/app/srs_app_rtc_conn.cpp @@ -112,8 +112,7 @@ SrsSecurityTransport::SrsSecurityTransport(SrsRtcSession* s) session_ = s; dtls_ = new SrsDtls((ISrsDtlsCallback*)this); - srtp_send = new SrsSRTP(); - srtp_recv = new SrsSRTP(); + srtp_ = new SrsSRTP(); handshake_done = false; } @@ -125,12 +124,9 @@ SrsSecurityTransport::~SrsSecurityTransport() dtls_ = NULL; } - if (srtp_send) { - srs_freep(srtp_send); - } - - if (srtp_recv) { - srs_freep(srtp_recv); + if (srtp_) { + srs_freep(srtp_); + srtp_ = NULL; } } @@ -199,62 +195,58 @@ srs_error_t SrsSecurityTransport::srtp_initialize() if ((err = dtls_->get_srtp_key(recv_key, send_key)) != srs_success) { return err; } - - if ((err = srtp_send->initialize(send_key, true)) != srs_success) { - return srs_error_wrap(err, "srtp send init failed"); - } - - if ((err = srtp_recv->initialize(recv_key, false)) != srs_success) { - return srs_error_wrap(err, "srtp recv init failed"); + + if ((err = srtp_->initialize(recv_key, send_key)) != srs_success) { + return srs_error_wrap(err, "srtp init failed"); } return err; } -srs_error_t SrsSecurityTransport::protect_rtp(char* out_buf, const char* in_buf, int& nb_out_buf) +srs_error_t SrsSecurityTransport::protect_rtp(const char* plaintext, char* cipher, int& nb_cipher) { - if (!srtp_send) { + if (!srtp_) { return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtp protect failed"); } - return srtp_send->protect_rtp(out_buf, in_buf, nb_out_buf); + return srtp_->protect_rtp(plaintext, cipher, nb_cipher); +} + +srs_error_t SrsSecurityTransport::protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher) +{ + if (!srtp_) { + return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtcp protect failed"); + } + + return srtp_->protect_rtcp(plaintext, cipher, nb_cipher); } // TODO: FIXME: Merge with protect_rtp. srs_error_t SrsSecurityTransport::protect_rtp2(void* rtp_hdr, int* len_ptr) { - if (!srtp_send) { + if (!srtp_) { return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtp protect"); } - return srtp_send->protect_rtp2(rtp_hdr, len_ptr); + return srtp_->protect_rtp2(rtp_hdr, len_ptr); } -srs_error_t SrsSecurityTransport::unprotect_rtp(char* out_buf, const char* in_buf, int& nb_out_buf) +srs_error_t SrsSecurityTransport::unprotect_rtp(const char* cipher, char* plaintext, int& nb_plaintext) { - if (!srtp_recv) { + if (!srtp_) { return srs_error_new(ERROR_RTC_SRTP_UNPROTECT, "rtp unprotect failed"); } - return srtp_recv->unprotect_rtp(out_buf, in_buf, nb_out_buf); -} - -srs_error_t SrsSecurityTransport::protect_rtcp(char* out_buf, const char* in_buf, int& nb_out_buf) -{ - if (!srtp_send) { - return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtcp protect failed"); - } - - return srtp_send->protect_rtcp(out_buf, in_buf, nb_out_buf); + return srtp_->unprotect_rtp(cipher, plaintext, nb_plaintext); } -srs_error_t SrsSecurityTransport::unprotect_rtcp(char* out_buf, const char* in_buf, int& nb_out_buf) +srs_error_t SrsSecurityTransport::unprotect_rtcp(const char* cipher, char* plaintext, int& nb_plaintext) { - if (!srtp_recv) { + if (!srtp_) { return srs_error_new(ERROR_RTC_SRTP_UNPROTECT, "rtcp unprotect failed"); } - return srtp_recv->unprotect_rtcp(out_buf, in_buf, nb_out_buf); + return srtp_->unprotect_rtcp(cipher, plaintext, nb_plaintext); } SrsRtcOutgoingInfo::SrsRtcOutgoingInfo() @@ -1059,7 +1051,7 @@ srs_error_t SrsRtcPublisher::send_rtcp_rr(uint32_t ssrc, SrsRtpRingBuffer* rtp_q char protected_buf[kRtpPacketSize]; int nb_protected_buf = stream.pos(); - if ((err = session_->transport_->protect_rtcp(protected_buf, stream.data(), nb_protected_buf)) != srs_success) { + if ((err = session_->transport_->protect_rtcp(stream.data(), protected_buf, nb_protected_buf)) != srs_success) { return srs_error_wrap(err, "protect rtcp rr"); } @@ -1119,7 +1111,7 @@ srs_error_t SrsRtcPublisher::send_rtcp_xr_rrtr(uint32_t ssrc) char protected_buf[kRtpPacketSize]; int nb_protected_buf = stream.pos(); - if ((err = session_->transport_->protect_rtcp(protected_buf, stream.data(), nb_protected_buf)) != srs_success) { + if ((err = session_->transport_->protect_rtcp(stream.data(), protected_buf, nb_protected_buf)) != srs_success) { return srs_error_wrap(err, "protect rtcp xr"); } @@ -1156,7 +1148,7 @@ srs_error_t SrsRtcPublisher::send_rtcp_fb_pli(uint32_t ssrc) char protected_buf[kRtpPacketSize]; int nb_protected_buf = stream.pos(); - if ((err = session_->transport_->protect_rtcp(protected_buf, stream.data(), nb_protected_buf)) != srs_success) { + if ((err = session_->transport_->protect_rtcp(stream.data(), protected_buf, nb_protected_buf)) != srs_success) { return srs_error_wrap(err, "protect rtcp psfb pli"); } @@ -1187,7 +1179,7 @@ srs_error_t SrsRtcPublisher::on_twcc(uint16_t sn) { } int nb_protected_buf = buffer->pos(); char protected_buf[kRtpPacketSize]; - if (session_->transport_->protect_rtcp(protected_buf, pkt, nb_protected_buf) == srs_success) { + if (session_->transport_->protect_rtcp(pkt, protected_buf, nb_protected_buf) == srs_success) { session_->sendonly_skt->sendto(protected_buf, nb_protected_buf, 0); } } @@ -1207,7 +1199,7 @@ srs_error_t SrsRtcPublisher::on_rtp(char* data, int nb_data) // Decrypt the cipher to plaintext RTP data. int nb_unprotected_buf = nb_data; char* unprotected_buf = new char[kRtpPacketSize]; - if ((err = session_->transport_->unprotect_rtp(unprotected_buf, data, nb_unprotected_buf)) != srs_success) { + if ((err = session_->transport_->unprotect_rtp(data, unprotected_buf, nb_unprotected_buf)) != srs_success) { // We try to decode the RTP header for more detail error informations. SrsBuffer b0(data, nb_data); SrsRtpHeader h0; h0.decode(&b0); err = srs_error_wrap(err, "marker=%u, pt=%u, seq=%u, ts=%u, ssrc=%u, pad=%u, payload=%uB", h0.get_marker(), h0.get_payload_type(), @@ -1950,7 +1942,7 @@ srs_error_t SrsRtcSession::on_rtcp(char* data, int nb_data) char unprotected_buf[kRtpPacketSize]; int nb_unprotected_buf = nb_data; - if ((err = transport_->unprotect_rtcp(unprotected_buf, data, nb_unprotected_buf)) != srs_success) { + if ((err = transport_->unprotect_rtcp(data, unprotected_buf, nb_unprotected_buf)) != srs_success) { return srs_error_wrap(err, "rtcp unprotect failed"); } diff --git a/trunk/src/app/srs_app_rtc_conn.hpp b/trunk/src/app/srs_app_rtc_conn.hpp index 09e2fe35c..9a5b71e61 100644 --- a/trunk/src/app/srs_app_rtc_conn.hpp +++ b/trunk/src/app/srs_app_rtc_conn.hpp @@ -111,9 +111,7 @@ class SrsSecurityTransport : public ISrsDtlsCallback private: SrsRtcSession* session_; SrsDtls* dtls_; - SrsSRTP* srtp_send; - SrsSRTP* srtp_recv; - + SrsSRTP* srtp_; bool handshake_done; public: SrsSecurityTransport(SrsRtcSession* s); @@ -124,11 +122,19 @@ public: srs_error_t do_handshake(); srs_error_t on_dtls(char* data, int nb_data); public: - srs_error_t protect_rtp(char* protected_buf, const char* ori_buf, int& nb_protected_buf); + // Encrypt the input plaintext to output cipher with nb_cipher bytes. + // @remark Note that the nb_cipher is the size of input plaintext, and + // it also is the length of output cipher when return. + srs_error_t protect_rtp(const char* plaintext, char* cipher, int& nb_cipher); + srs_error_t protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher); + // Encrypt the input rtp_hdr with *len_ptr bytes. + // @remark the input plaintext and out cipher reuse rtp_hdr. srs_error_t protect_rtp2(void* rtp_hdr, int* len_ptr); - srs_error_t unprotect_rtp(char* unprotected_buf, const char* ori_buf, int& nb_unprotected_buf); - srs_error_t protect_rtcp(char* protected_buf, const char* ori_buf, int& nb_protected_buf); - srs_error_t unprotect_rtcp(char* unprotected_buf, const char* ori_buf, int& nb_unprotected_buf); + // Decrypt the input cipher to output cipher with nb_cipher bytes. + // @remark Note that the nb_plaintext is the size of input cipher, and + // it also is the length of output plaintext when return. + srs_error_t unprotect_rtp(const char* cipher, char* plaintext, int& nb_plaintext); + srs_error_t unprotect_rtcp(const char* cipher, char* plaintext, int& nb_plaintext); // implement ISrsDtlsCallback public: virtual srs_error_t on_dtls_handshake_done(); diff --git a/trunk/src/app/srs_app_rtc_dtls.cpp b/trunk/src/app/srs_app_rtc_dtls.cpp index 895b1bdd8..58953d3ae 100644 --- a/trunk/src/app/srs_app_rtc_dtls.cpp +++ b/trunk/src/app/srs_app_rtc_dtls.cpp @@ -30,6 +30,7 @@ using namespace std; #include #include #include +#include #include #include @@ -457,17 +458,22 @@ srs_error_t SrsDtls::get_srtp_key(std::string& recv_key, std::string& send_key) SrsSRTP::SrsSRTP() { - srtp_ctx = NULL; + recv_ctx_ = NULL; + send_ctx_ = NULL; } SrsSRTP::~SrsSRTP() { - if (srtp_ctx) { - srtp_dealloc(srtp_ctx); + if (recv_ctx_) { + srtp_dealloc(recv_ctx_); + } + + if (send_ctx_) { + srtp_dealloc(send_ctx_); } } -srs_error_t SrsSRTP::initialize(string srtp_key, bool send) +srs_error_t SrsSRTP::initialize(string recv_key, std::string send_key) { srs_error_t err = srs_success; @@ -480,93 +486,94 @@ srs_error_t SrsSRTP::initialize(string srtp_key, bool send) srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtp); srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); - if (send) { - policy.ssrc.type = ssrc_any_outbound; - } else { - policy.ssrc.type = ssrc_any_inbound; - } - - policy.ssrc.value = 0; // TODO: adjust window_size policy.window_size = 8192; policy.allow_repeat_tx = 1; policy.next = NULL; - //uint8_t *key = new uint8_t[server_key.size()]; - //memcpy(key, server_key.data(), server_key.size()); - uint8_t *key = new uint8_t[srtp_key.size()]; - memcpy(key, srtp_key.data(), srtp_key.size()); - policy.key = key; + // init recv context + policy.ssrc.type = ssrc_any_inbound; + uint8_t *rkey = new uint8_t[recv_key.size()]; + SrsAutoFreeA(uint8_t, rkey); + memcpy(rkey, recv_key.data(), recv_key.size()); + policy.key = rkey; - if (srtp_create(&srtp_ctx, &policy) != srtp_err_status_ok) { - srs_freepa(key); - return srs_error_new(ERROR_RTC_SRTP_INIT, "srtp_create failed"); + if (srtp_create(&recv_ctx_, &policy) != srtp_err_status_ok) { + return srs_error_new(ERROR_RTC_SRTP_INIT, "srtp_create recv failed"); } - srs_freepa(key); + policy.ssrc.type = ssrc_any_outbound; + uint8_t *skey = new uint8_t[send_key.size()]; + SrsAutoFreeA(uint8_t, skey); + memcpy(skey, send_key.data(), send_key.size()); + policy.key = skey; + + if (srtp_create(&send_ctx_, &policy) != srtp_err_status_ok) { + return srs_error_new(ERROR_RTC_SRTP_INIT, "srtp_create recv failed"); + } return err; } -srs_error_t SrsSRTP::protect_rtp(char* out_buf, const char* in_buf, int& nb_out_buf) +srs_error_t SrsSRTP::protect_rtp(const char* plaintext, char* cipher, int& nb_cipher) { srs_error_t err = srs_success; - memcpy(out_buf, in_buf, nb_out_buf); + memcpy(cipher, plaintext, nb_cipher); // TODO: FIXME: Wrap error code. - if (srtp_protect(srtp_ctx, out_buf, &nb_out_buf) != 0) { + if (srtp_protect(send_ctx_, cipher, &nb_cipher) != 0) { return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtp protect failed"); } return err; } -srs_error_t SrsSRTP::protect_rtp2(void* rtp_hdr, int* len_ptr) +srs_error_t SrsSRTP::protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher) { srs_error_t err = srs_success; + memcpy(cipher, plaintext, nb_cipher); // TODO: FIXME: Wrap error code. - if (srtp_protect(srtp_ctx, rtp_hdr, len_ptr) != 0) { - return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtp protect"); + if (srtp_protect_rtcp(send_ctx_, cipher, &nb_cipher) != 0) { + return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtcp protect failed"); } return err; } -srs_error_t SrsSRTP::unprotect_rtp(char* out_buf, const char* in_buf, int& nb_out_buf) +srs_error_t SrsSRTP::protect_rtp2(void* rtp_hdr, int* len_ptr) { srs_error_t err = srs_success; - memcpy(out_buf, in_buf, nb_out_buf); - srtp_err_status_t r0 = srtp_unprotect(srtp_ctx, out_buf, &nb_out_buf); - if (r0 != srtp_err_status_ok) { - return srs_error_new(ERROR_RTC_SRTP_UNPROTECT, "unprotect r0=%u", r0); + // TODO: FIXME: Wrap error code. + if (srtp_protect(send_ctx_, rtp_hdr, len_ptr) != 0) { + return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtp protect"); } return err; } -srs_error_t SrsSRTP::protect_rtcp(char* out_buf, const char* in_buf, int& nb_out_buf) +srs_error_t SrsSRTP::unprotect_rtp(const char* cipher, char* plaintext, int& nb_plaintext) { srs_error_t err = srs_success; - memcpy(out_buf, in_buf, nb_out_buf); - // TODO: FIXME: Wrap error code. - if (srtp_protect_rtcp(srtp_ctx, out_buf, &nb_out_buf) != 0) { - return srs_error_new(ERROR_RTC_SRTP_PROTECT, "rtcp protect failed"); + memcpy(plaintext, cipher, nb_plaintext); + srtp_err_status_t r0 = srtp_unprotect(recv_ctx_, plaintext, &nb_plaintext); + if (r0 != srtp_err_status_ok) { + return srs_error_new(ERROR_RTC_SRTP_UNPROTECT, "unprotect r0=%u", r0); } return err; } -srs_error_t SrsSRTP::unprotect_rtcp(char* out_buf, const char* in_buf, int& nb_out_buf) +srs_error_t SrsSRTP::unprotect_rtcp(const char* cipher, char* plaintext, int& nb_plaintext) { srs_error_t err = srs_success; - memcpy(out_buf, in_buf, nb_out_buf); + memcpy(plaintext, cipher, nb_plaintext); // TODO: FIXME: Wrap error code. - if (srtp_unprotect_rtcp(srtp_ctx, out_buf, &nb_out_buf) != srtp_err_status_ok) { + if (srtp_unprotect_rtcp(recv_ctx_, plaintext, &nb_plaintext) != srtp_err_status_ok) { return srs_error_new(ERROR_RTC_SRTP_UNPROTECT, "rtcp unprotect failed"); } diff --git a/trunk/src/app/srs_app_rtc_dtls.hpp b/trunk/src/app/srs_app_rtc_dtls.hpp index 41eaa7bdb..759042457 100644 --- a/trunk/src/app/srs_app_rtc_dtls.hpp +++ b/trunk/src/app/srs_app_rtc_dtls.hpp @@ -103,18 +103,28 @@ private: class SrsSRTP { private: - srtp_t srtp_ctx; + srtp_t recv_ctx_; + srtp_t send_ctx_; public: SrsSRTP(); virtual ~SrsSRTP(); public: - srs_error_t initialize(std::string srtp_key, bool send); + // Intialize srtp context with recv_key and send_key. + srs_error_t initialize(std::string recv_key, std::string send_key); public: - srs_error_t protect_rtp(char* protected_buf, const char* ori_buf, int& nb_protected_buf); + // Encrypt the input plaintext to output cipher with nb_cipher bytes. + // @remark Note that the nb_cipher is the size of input plaintext, and + // it also is the length of output cipher when return. + srs_error_t protect_rtp(const char* plaintext, char* cipher, int& nb_cipher); + srs_error_t protect_rtcp(const char* plaintext, char* cipher, int& nb_cipher); + // Encrypt the input rtp_hdr with *len_ptr bytes. + // @remark the input plaintext and out cipher reuse rtp_hdr. srs_error_t protect_rtp2(void* rtp_hdr, int* len_ptr); - srs_error_t unprotect_rtp(char* unprotected_buf, const char* ori_buf, int& nb_unprotected_buf); - srs_error_t protect_rtcp(char* protected_buf, const char* ori_buf, int& nb_protected_buf); - srs_error_t unprotect_rtcp(char* unprotected_buf, const char* ori_buf, int& nb_unprotected_buf); + // Decrypt the input cipher to output cipher with nb_cipher bytes. + // @remark Note that the nb_plaintext is the size of input cipher, and + // it also is the length of output plaintext when return. + srs_error_t unprotect_rtp(const char* cipher, char* plaintext, int& nb_plaintext); + srs_error_t unprotect_rtcp(const char* cipher, char* plaintext, int& nb_plaintext); }; #endif