diff --git a/trunk/src/protocol/srs_http_stack.cpp b/trunk/src/protocol/srs_http_stack.cpp index b798c7de7..a95dc7ea2 100644 --- a/trunk/src/protocol/srs_http_stack.cpp +++ b/trunk/src/protocol/srs_http_stack.cpp @@ -156,14 +156,23 @@ void SrsHttpHeader::set(string key, string value) string SrsHttpHeader::get(string key) { std::string v; - - if (headers.find(key) != headers.end()) { - v = headers[key]; + + map::iterator it = headers.find(key); + if (it != headers.end()) { + v = it->second; } return v; } +void SrsHttpHeader::del(string key) +{ + map::iterator it = headers.find(key); + if (it != headers.end()) { + headers.erase(it); + } +} + int64_t SrsHttpHeader::content_length() { std::string cl = get("Content-Length"); @@ -192,7 +201,7 @@ void SrsHttpHeader::set_content_type(string ct) void SrsHttpHeader::write(stringstream& ss) { - std::map::iterator it; + map::iterator it; for (it = headers.begin(); it != headers.end(); ++it) { ss << it->first << ": " << it->second << SRS_HTTP_CRLF; } diff --git a/trunk/src/protocol/srs_http_stack.hpp b/trunk/src/protocol/srs_http_stack.hpp index 062178d57..f62dcb2c7 100644 --- a/trunk/src/protocol/srs_http_stack.hpp +++ b/trunk/src/protocol/srs_http_stack.hpp @@ -121,6 +121,9 @@ public: // To access multiple values of a key, access the map directly // with CanonicalHeaderKey. virtual std::string get(std::string key); + // Delete the http header indicated by key. + // Return the removed header field. + virtual void del(std::string); public: // Get the content length. -1 if not set. virtual int64_t content_length(); diff --git a/trunk/src/service/srs_service_http_conn.cpp b/trunk/src/service/srs_service_http_conn.cpp index c7cd873e0..698d4b4a1 100644 --- a/trunk/src/service/srs_service_http_conn.cpp +++ b/trunk/src/service/srs_service_http_conn.cpp @@ -614,7 +614,15 @@ bool SrsHttpMessage::is_jsonp() return jsonp; } -SrsHttpResponseWriter::SrsHttpResponseWriter(SrsStSocket* io) +ISrsHttpHeaderFilter::ISrsHttpHeaderFilter() +{ +} + +ISrsHttpHeaderFilter::~ISrsHttpHeaderFilter() +{ +} + +SrsHttpResponseWriter::SrsHttpResponseWriter(ISrsProtocolReadWriter* io) { skt = io; hdr = new SrsHttpHeader(); @@ -625,6 +633,7 @@ SrsHttpResponseWriter::SrsHttpResponseWriter(SrsStSocket* io) header_sent = false; nb_iovss_cache = 0; iovss_cache = NULL; + hf = NULL; } SrsHttpResponseWriter::~SrsHttpResponseWriter() @@ -840,6 +849,11 @@ srs_error_t SrsHttpResponseWriter::send_header(char* data, int size) // keep alive to make vlc happy. hdr->set("Connection", "Keep-Alive"); + + // Filter the header before writing it. + if (hf && ((err = hf->filter(hdr)) != srs_success)) { + return srs_error_wrap(err, "filter header"); + } // write headers hdr->write(ss); diff --git a/trunk/src/service/srs_service_http_conn.hpp b/trunk/src/service/srs_service_http_conn.hpp index 3bfdfb4ee..4ac4d043f 100644 --- a/trunk/src/service/srs_service_http_conn.hpp +++ b/trunk/src/service/srs_service_http_conn.hpp @@ -35,7 +35,7 @@ class SrsFastStream; class SrsRequest; class ISrsReader; class SrsHttpResponseReader; -class SrsStSocket; +class ISrsProtocolReadWriter; // A wrapper for http-parser, // provides HTTP message originted service. @@ -195,12 +195,25 @@ public: // for writev, there always one chunk to send it. #define SRS_HTTP_HEADER_CACHE_SIZE 64 +class ISrsHttpHeaderFilter +{ +public: + ISrsHttpHeaderFilter(); + virtual ~ISrsHttpHeaderFilter(); +public: + // Filter the HTTP header h. + virtual srs_error_t filter(SrsHttpHeader* h) = 0; +}; + // Response writer use st socket class SrsHttpResponseWriter : public ISrsHttpResponseWriter { private: - SrsStSocket* skt; + ISrsProtocolReadWriter* skt; SrsHttpHeader* hdr; + // Before writing header, there is a chance to filter it, + // such as remove some headers or inject new. + ISrsHttpHeaderFilter* hf; private: char header_cache[SRS_HTTP_HEADER_CACHE_SIZE]; iovec* iovss_cache; @@ -222,7 +235,7 @@ private: // logically written. bool header_sent; public: - SrsHttpResponseWriter(SrsStSocket* io); + SrsHttpResponseWriter(ISrsProtocolReadWriter* io); virtual ~SrsHttpResponseWriter(); public: virtual srs_error_t final_request(); diff --git a/trunk/src/utest/srs_utest.hpp b/trunk/src/utest/srs_utest.hpp index 34ee90a0c..e128d350e 100644 --- a/trunk/src/utest/srs_utest.hpp +++ b/trunk/src/utest/srs_utest.hpp @@ -35,8 +35,10 @@ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. #include "gtest/gtest.h" #include +using namespace std; #include +#include // we add an empty macro for upp to show the smart tips. #define VOID @@ -61,6 +63,10 @@ extern srs_utime_t _srs_tmp_timeout; #define HELPER_ARRAY_INIT(buf, sz, val) \ for (int i = 0; i < (int)sz; i++) (buf)[i]=val +// Dump simple stream to string. +#define HELPER_BUFFER2STR(io) \ + string((const char*)(io)->bytes(), (size_t)(io)->length()) + // the asserts of gtest: // * {ASSERT|EXPECT}_EQ(expected, actual): Tests that expected == actual // * {ASSERT|EXPECT}_NE(v1, v2): Tests that v1 != v2 diff --git a/trunk/src/utest/srs_utest_http.cpp b/trunk/src/utest/srs_utest_http.cpp index 6b33b04d6..217071807 100644 --- a/trunk/src/utest/srs_utest_http.cpp +++ b/trunk/src/utest/srs_utest_http.cpp @@ -22,7 +22,84 @@ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include +#include +using namespace std; + #include +#include +#include + +class MockResponseWriter : virtual public ISrsHttpResponseWriter, virtual public ISrsHttpHeaderFilter +{ +public: + SrsHttpResponseWriter* w; + MockBufferIO io; +public: + MockResponseWriter(); + virtual ~MockResponseWriter(); +public: + virtual srs_error_t final_request(); + virtual SrsHttpHeader* header(); + virtual srs_error_t write(char* data, int size); + virtual srs_error_t writev(const iovec* iov, int iovcnt, ssize_t* pnwrite); + virtual void write_header(int code); +public: + virtual srs_error_t filter(SrsHttpHeader* h); +}; + +MockResponseWriter::MockResponseWriter() +{ + w = new SrsHttpResponseWriter(&io); + w->hf = this; +} + +MockResponseWriter::~MockResponseWriter() +{ + srs_freep(w); +} + +srs_error_t MockResponseWriter::final_request() +{ + return w->final_request(); +} + +SrsHttpHeader* MockResponseWriter::header() +{ + return w->header(); +} + +srs_error_t MockResponseWriter::write(char* data, int size) +{ + return w->write(data, size); +} + +srs_error_t MockResponseWriter::writev(const iovec* iov, int iovcnt, ssize_t* pnwrite) +{ + return w->writev(iov, iovcnt, pnwrite); +} + +void MockResponseWriter::write_header(int code) +{ + w->write_header(code); +} + +srs_error_t MockResponseWriter::filter(SrsHttpHeader* h) +{ + h->del("Content-Type"); + h->del("Server"); + h->del("Connection"); + return srs_success; +} + +string mock_http_response(int status, string content) +{ + stringstream ss; + ss << "HTTP/1.1 " << status << " " << srs_generate_http_status_text(status) << "\r\n" + << "Content-Length: " << content.length() << "\r\n" + << "\r\n" + << content; + return ss.str(); +} VOID TEST(ProtocolHTTPTest, StatusCode2Text) { @@ -40,3 +117,14 @@ VOID TEST(ProtocolHTTPTest, ResponseDetect) { EXPECT_STREQ("application/octet-stream", srs_go_http_detect(NULL, 0).c_str()); } + +VOID TEST(ProtocolHTTPTest, ResponseHTTPError) +{ + srs_error_t err; + + if (true) { + MockResponseWriter w; + HELPER_EXPECT_SUCCESS(srs_go_http_error(&w, SRS_CONSTS_HTTP_Found)); + EXPECT_STREQ(mock_http_response(302,"Found").c_str(), HELPER_BUFFER2STR(&w.io.out_buffer).c_str()); + } +}