Allow customization of the Server header

This commit is contained in:
Andrew Rabert 2022-02-04 19:10:08 -05:00
parent faaa7bb760
commit 282f8258ca
2 changed files with 63 additions and 0 deletions

View file

@ -214,6 +214,28 @@ class Application(CallbackManager, web.Application):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(self, *args, **kwargs) super().__init__(self, *args, **kwargs)
service = self.settings.get('service')
if service:
parts = [service]
version = self.settings.get('version')
if version:
parts.append(version)
self.settings.setdefault('server_header', '/'.join(parts))
else:
self.settings.setdefault('server_header', None)
# use a closure to `self.settings` to allow for dynamic configuration
class ServerHeaderTransform(web.OutputTransform):
def transform_first_chunk(_, status_code, headers, chunk,
finishing):
value = self.settings.get('server_header')
if value is None:
headers.pop('Server', None)
else:
headers['Server'] = value
return status_code, headers, chunk
self.add_transform(ServerHeaderTransform)
def log_request(self, handler): def log_request(self, handler):
"""Customized access log function. """Customized access log function.

View file

@ -918,3 +918,44 @@ class AccessLogTests(sprockets.http.testing.SprocketsHttpTestCase):
with self.assertLogs(log.access_log, log_level) as context: with self.assertLogs(log.access_log, log_level) as context:
self.app.log_request(handler) self.app.log_request(handler)
self.assertEqual(context.records[0].levelno, log_level) self.assertEqual(context.records[0].levelno, log_level)
class ServerHeaderTests(sprockets.http.testing.SprocketsHttpTestCase):
def get_app(self):
self.app = sprockets.http.app.Application(
server_header='a/b/c')
return self.app
def test_reads_from_settings(self):
self.app.settings['server_header'] = 'a/b/c'
response = self.fetch('/')
self.assertEqual('a/b/c', response.headers['Server'])
self.app.settings['server_header'] = 'some server'
response = self.fetch('/')
self.assertEqual('some server', response.headers['Server'])
self.app.settings['server_header'] = None
response = self.fetch('/')
self.assertNotIn('Server', response.headers)
self.app.settings.pop('server_header')
response = self.fetch('/')
self.assertNotIn('Server', response.headers)
def test_defaults(self):
app = sprockets.http.app.Application()
self.assertIsNone(app.settings['server_header'])
app = sprockets.http.app.Application(service='myservice',
version='myversion')
self.assertEqual('myservice/myversion', app.settings['server_header'])
app = sprockets.http.app.Application(service='myservice',
version=None)
self.assertEqual('myservice', app.settings['server_header'])
app = sprockets.http.app.Application(service=None,
version='myversion')
self.assertIsNone(app.settings['server_header'])