@@ -551,13 +551,17 @@ def send_response_only(self, code, message=None):
551551 (self .protocol_version , code , message )).encode (
552552 'latin-1' , 'strict' ))
553553
554- def send_header (self , keyword , value ):
554+ def send_header (self , keyword , value , * , _is_extra = False ):
555555 """Send a MIME header to the headers buffer."""
556556 if self .request_version != 'HTTP/0.9' :
557557 if not hasattr (self , '_headers_buffer' ):
558558 self ._headers_buffer = []
559559 self ._headers_buffer .append (
560560 ("%s: %s\r \n " % (keyword , value )).encode ('latin-1' , 'strict' ))
561+ if not hasattr (self , '_default_response_headers' ):
562+ self ._default_response_headers = []
563+ if not _is_extra :
564+ self ._default_response_headers .append ((keyword , value ))
561565
562566 if keyword .lower () == 'connection' :
563567 if value .lower () == 'close' :
@@ -575,6 +579,8 @@ def flush_headers(self):
575579 if hasattr (self , '_headers_buffer' ):
576580 self .wfile .write (b"" .join (self ._headers_buffer ))
577581 self ._headers_buffer = []
582+ if hasattr (self , '_default_response_headers' ):
583+ self ._default_response_headers = []
578584
579585 def _colorize_request (self , code , size , t ):
580586 try :
@@ -736,10 +742,11 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
736742 '.xz' : 'application/x-xz' ,
737743 }
738744
739- def __init__ (self , * args , directory = None , ** kwargs ):
745+ def __init__ (self , * args , directory = None , extra_response_headers = None , ** kwargs ):
740746 if directory is None :
741747 directory = os .getcwd ()
742748 self .directory = os .fspath (directory )
749+ self .extra_response_headers = extra_response_headers
743750 super ().__init__ (* args , ** kwargs )
744751
745752 def do_GET (self ):
@@ -757,6 +764,16 @@ def do_HEAD(self):
757764 if f :
758765 f .close ()
759766
767+ def _send_extra_response_headers (self ):
768+ """Send the headers stored in self.extra_response_headers."""
769+ if self .extra_response_headers is not None :
770+ default_headers = {h .lower () for h , _ in self ._default_response_headers }
771+ for header , value in self .extra_response_headers :
772+ # Don't send the header if it's already sent
773+ # as part of the default response headers
774+ if header .lower () not in default_headers :
775+ self .send_header (header , value , _is_extra = True )
776+
760777 def send_head (self ):
761778 """Common code for GET and HEAD commands.
762779
@@ -839,6 +856,7 @@ def send_head(self):
839856 self .send_header ("Content-Length" , str (fs [6 ]))
840857 self .send_header ("Last-Modified" ,
841858 self .date_time_string (fs .st_mtime ))
859+ self ._send_extra_response_headers ()
842860 self .end_headers ()
843861 return f
844862 except :
@@ -903,6 +921,7 @@ def list_directory(self, path):
903921 self .send_response (HTTPStatus .OK )
904922 self .send_header ("Content-type" , "text/html; charset=%s" % enc )
905923 self .send_header ("Content-Length" , str (len (encoded )))
924+ self ._send_extra_response_headers ()
906925 self .end_headers ()
907926 return f
908927
@@ -1011,6 +1030,22 @@ def _get_best_family(*address):
10111030 return family , sockaddr
10121031
10131032
1033+ def _make_server (HandlerClass = BaseHTTPRequestHandler ,
1034+ ServerClass = ThreadingHTTPServer ,
1035+ protocol = "HTTP/1.0" , port = 8000 , bind = None ,
1036+ tls_cert = None , tls_key = None , tls_password = None ,
1037+ default_content_type = SimpleHTTPRequestHandler .default_content_type ):
1038+ ServerClass .address_family , addr = _get_best_family (bind , port )
1039+ HandlerClass .protocol_version = protocol
1040+ HandlerClass .default_content_type = default_content_type
1041+
1042+ if tls_cert :
1043+ return ServerClass (addr , HandlerClass , certfile = tls_cert ,
1044+ keyfile = tls_key , password = tls_password )
1045+ else :
1046+ return ServerClass (addr , HandlerClass )
1047+
1048+
10141049def test (HandlerClass = SimpleHTTPRequestHandler ,
10151050 ServerClass = ThreadingHTTPServer ,
10161051 protocol = "HTTP/1.0" , port = 8000 , bind = None ,
@@ -1019,19 +1054,13 @@ def test(HandlerClass=SimpleHTTPRequestHandler,
10191054 """Test the HTTP request handler class.
10201055
10211056 This runs an HTTP server on port 8000 (or the port argument).
1022-
10231057 """
1024- ServerClass .address_family , addr = _get_best_family (bind , port )
1025- HandlerClass .protocol_version = protocol
1026- HandlerClass .default_content_type = content_type
1027-
1028- if tls_cert :
1029- server = ServerClass (addr , HandlerClass , certfile = tls_cert ,
1030- keyfile = tls_key , password = tls_password )
1031- else :
1032- server = ServerClass (addr , HandlerClass )
1033-
1034- with server as httpd :
1058+ with _make_server (
1059+ HandlerClass = HandlerClass , ServerClass = ServerClass ,
1060+ protocol = protocol , port = port , bind = bind ,
1061+ tls_cert = tls_cert , tls_key = tls_key , tls_password = tls_password ,
1062+ default_content_type = content_type ,
1063+ ) as httpd :
10351064 host , port = httpd .socket .getsockname ()[:2 ]
10361065 url_host = f'[{ host } ]' if ':' in host else host
10371066 protocol = 'HTTPS' if tls_cert else 'HTTP'
@@ -1076,6 +1105,10 @@ def _main(args=None):
10761105 parser .add_argument ('port' , default = 8000 , type = int , nargs = '?' ,
10771106 help = 'bind to this port '
10781107 '(default: %(default)s)' )
1108+ parser .add_argument ('-H' , '--header' , nargs = 2 , action = 'append' ,
1109+ metavar = ('HEADER' , 'VALUE' ),
1110+ help = 'Add a custom response header '
1111+ '(can be specified multiple times)' )
10791112 args = parser .parse_args (args )
10801113
10811114 if not args .tls_cert and args .tls_key :
@@ -1104,7 +1137,8 @@ def server_bind(self):
11041137
11051138 def finish_request (self , request , client_address ):
11061139 self .RequestHandlerClass (request , client_address , self ,
1107- directory = args .directory )
1140+ directory = args .directory ,
1141+ extra_response_headers = args .header )
11081142
11091143 class HTTPDualStackServer (DualStackServerMixin , ThreadingHTTPServer ):
11101144 pass
0 commit comments