8
8
9
9
#include " openssl/sha.h" // Sha-1 hash
10
10
11
+ #include < map>
11
12
#include < string.h>
12
- #include < vector>
13
13
14
14
#define ACCEPT_KEY_LENGTH base64_encoded_size (20 )
15
15
#define BUFFER_GROWTH_CHUNK_SIZE 1024
@@ -63,7 +63,7 @@ class ProtocolHandler {
63
63
virtual void Write (const std::vector<char > data) = 0;
64
64
virtual void CancelHandshake () = 0;
65
65
66
- std::string GetHost ();
66
+ std::string GetHost () const ;
67
67
68
68
InspectorSocket* inspector () {
69
69
return inspector_;
@@ -160,6 +160,48 @@ static void generate_accept_string(const std::string& client_key,
160
160
node::base64_encode (hash, sizeof (hash), *buffer, sizeof (*buffer));
161
161
}
162
162
163
+ static bool IsOneOf (const std::string& host,
164
+ const std::vector<std::string>& hosts) {
165
+ for (const std::string& candidate : hosts) {
166
+ if (node::StringEqualNoCase (host.data (), candidate.data ()))
167
+ return true ;
168
+ }
169
+ return false ;
170
+ }
171
+
172
+ static std::string TrimPort (const std::string& host) {
173
+ size_t last_colon_pos = host.rfind (" :" );
174
+ if (last_colon_pos == std::string::npos)
175
+ return host;
176
+ size_t bracket = host.rfind (" ]" );
177
+ if (bracket == std::string::npos || last_colon_pos > bracket)
178
+ return host.substr (0 , last_colon_pos);
179
+ return host;
180
+ }
181
+
182
+ static bool IsIPAddress (const std::string& host) {
183
+ if (host.length () >= 4 && host.front () == ' [' && host.back () == ' ]' )
184
+ return true ;
185
+ int quads = 0 ;
186
+ for (char c : host) {
187
+ if (c == ' .' )
188
+ quads++;
189
+ else if (!isdigit (c))
190
+ return false ;
191
+ }
192
+ return quads == 3 ;
193
+ }
194
+
195
+ // This is a value coming from the interface, it can only be IPv4 or IPv6
196
+ // address string.
197
+ static bool IsIPv4Localhost (const std::string& host) {
198
+ std::string v6_tunnel_prefix = " ::ffff:" ;
199
+ if (host.substr (0 , v6_tunnel_prefix.length ()) == v6_tunnel_prefix)
200
+ return IsIPv4Localhost (host.substr (v6_tunnel_prefix.length ()));
201
+ std::string localhost_net = " 127." ;
202
+ return host.substr (0 , localhost_net.length ()) == localhost_net;
203
+ }
204
+
163
205
// Constants for hybi-10 frame format.
164
206
165
207
typedef int OpCode;
@@ -298,7 +340,6 @@ static ws_decode_result decode_frame_hybi17(const std::vector<char>& buffer,
298
340
return closed ? FRAME_CLOSE : FRAME_OK;
299
341
}
300
342
301
-
302
343
// WS protocol
303
344
class WsHandler : public ProtocolHandler {
304
345
public:
@@ -400,17 +441,16 @@ class WsHandler : public ProtocolHandler {
400
441
// HTTP protocol
401
442
class HttpEvent {
402
443
public:
403
- HttpEvent (const std::string& path, bool upgrade,
404
- bool isGET, const std::string& ws_key) : path(path),
405
- upgrade (upgrade),
406
- isGET(isGET),
407
- ws_key(ws_key) { }
444
+ HttpEvent (const std::string& path, bool upgrade, bool isGET,
445
+ const std::string& ws_key, const std::string& host)
446
+ : path(path), upgrade(upgrade), isGET(isGET), ws_key(ws_key),
447
+ host (host) { }
408
448
409
449
std::string path;
410
450
bool upgrade;
411
451
bool isGET;
412
452
std::string ws_key;
413
- std::string current_header_ ;
453
+ std::string host ;
414
454
};
415
455
416
456
class HttpHandler : public ProtocolHandler {
@@ -472,18 +512,17 @@ class HttpHandler : public ProtocolHandler {
472
512
std::vector<HttpEvent> events;
473
513
std::swap (events, events_);
474
514
for (const HttpEvent& event : events) {
475
- bool shouldContinue = event.isGET && !event.upgrade ;
476
- if (!event.isGET ) {
515
+ if (!IsAllowedHost (event.host ) || !event.isGET ) {
477
516
CancelHandshake ();
517
+ return ;
478
518
} else if (!event.upgrade ) {
479
519
delegate ()->OnHttpGet (event.path );
480
520
} else if (event.ws_key .empty ()) {
481
521
CancelHandshake ();
522
+ return ;
482
523
} else {
483
524
delegate ()->OnSocketUpgrade (event.path , event.ws_key );
484
525
}
485
- if (!shouldContinue)
486
- return ;
487
526
}
488
527
}
489
528
@@ -504,16 +543,9 @@ class HttpHandler : public ProtocolHandler {
504
543
}
505
544
506
545
static int OnHeaderValue (http_parser* parser, const char * at, size_t length) {
507
- static const char SEC_WEBSOCKET_KEY_HEADER[] = " Sec-WebSocket-Key" ;
508
546
HttpHandler* handler = From (parser);
509
547
handler->parsing_value_ = true ;
510
- if (handler->current_header_ .size () ==
511
- sizeof (SEC_WEBSOCKET_KEY_HEADER) - 1 &&
512
- node::StringEqualNoCaseN (handler->current_header_ .data (),
513
- SEC_WEBSOCKET_KEY_HEADER,
514
- sizeof (SEC_WEBSOCKET_KEY_HEADER) - 1 )) {
515
- handler->ws_key_ .append (at, length);
516
- }
548
+ handler->headers_ [handler->current_header_ ].append (at, length);
517
549
return 0 ;
518
550
}
519
551
@@ -540,23 +572,53 @@ class HttpHandler : public ProtocolHandler {
540
572
static int OnMessageComplete (http_parser* parser) {
541
573
// Event needs to be fired after the parser is done.
542
574
HttpHandler* handler = From (parser);
543
- handler->events_ .push_back (HttpEvent (handler->path_ , parser->upgrade ,
544
- parser->method == HTTP_GET,
545
- handler->ws_key_ ));
575
+ handler->events_ .push_back (
576
+ HttpEvent (handler->path_ , parser->upgrade , parser->method == HTTP_GET,
577
+ handler->HeaderValue (" Sec-WebSocket-Key" ),
578
+ handler->HeaderValue (" Host" )));
546
579
handler->path_ = " " ;
547
- handler->ws_key_ = " " ;
548
580
handler->parsing_value_ = false ;
581
+ handler->headers_ .clear ();
549
582
handler->current_header_ = " " ;
550
-
551
583
return 0 ;
552
584
}
553
585
586
+ std::string HeaderValue (const std::string& header) const {
587
+ bool header_found = false ;
588
+ std::string value;
589
+ for (const auto & header_value : headers_) {
590
+ if (node::StringEqualNoCaseN (header_value.first .data (), header.data (),
591
+ header.length ())) {
592
+ if (header_found)
593
+ return " " ;
594
+ value = header_value.second ;
595
+ header_found = true ;
596
+ }
597
+ }
598
+ return value;
599
+ }
600
+
601
+ bool IsAllowedHost (const std::string& host_with_port) const {
602
+ std::string host = TrimPort (host_with_port);
603
+ if (host.empty ())
604
+ return false ;
605
+ if (IsIPAddress (host))
606
+ return true ;
607
+ std::string socket_host = GetHost ();
608
+ if (IsIPv4Localhost (socket_host)) {
609
+ return IsOneOf (host, { " localhost" });
610
+ } else if (socket_host == " ::1" ) {
611
+ return IsOneOf (host, { " localhost" , " localhost6" });
612
+ }
613
+ return true ;
614
+ }
615
+
554
616
bool parsing_value_;
555
617
http_parser parser_;
556
618
http_parser_settings parser_settings;
557
619
std::vector<HttpEvent> events_;
558
620
std::string current_header_;
559
- std::string ws_key_ ;
621
+ std::map<std:: string, std::string> headers_ ;
560
622
std::string path_;
561
623
};
562
624
@@ -579,7 +641,7 @@ InspectorSocket::Delegate* ProtocolHandler::delegate() {
579
641
return tcp_->delegate ();
580
642
}
581
643
582
- std::string ProtocolHandler::GetHost () {
644
+ std::string ProtocolHandler::GetHost () const {
583
645
char ip[INET6_ADDRSTRLEN];
584
646
sockaddr_storage addr;
585
647
int len = sizeof (addr);
@@ -622,8 +684,6 @@ TcpHolder::Pointer TcpHolder::Accept(
622
684
if (err == 0 ) {
623
685
return { result, DisconnectAndDispose };
624
686
} else {
625
- fprintf (stderr, " [%s:%d@%s]\n " , __FILE__, __LINE__, __FUNCTION__);
626
-
627
687
delete result;
628
688
return { nullptr , nullptr };
629
689
}
0 commit comments