diff --git a/sprockets/http/app.py b/sprockets/http/app.py index ae8c1af..0940b65 100644 --- a/sprockets/http/app.py +++ b/sprockets/http/app.py @@ -214,6 +214,28 @@ class Application(CallbackManager, web.Application): def __init__(self, *args, **kwargs): super().__init__(self, *args, **kwargs) + service = self.settings.get('service') + if service: + parts = [service] + version = self.settings.get('version') + if version: + parts.append(version) + self.settings.setdefault('server_header', '/'.join(parts)) + else: + self.settings.setdefault('server_header', None) + + # use a closure to `self.settings` to allow for dynamic configuration + class ServerHeaderTransform(web.OutputTransform): + def transform_first_chunk(_, status_code, headers, chunk, + finishing): + value = self.settings.get('server_header') + if value is None: + headers.pop('Server', None) + else: + headers['Server'] = value + return status_code, headers, chunk + + self.add_transform(ServerHeaderTransform) def log_request(self, handler): """Customized access log function. diff --git a/tests.py b/tests.py index 0cbcf46..71db30b 100644 --- a/tests.py +++ b/tests.py @@ -918,3 +918,44 @@ class AccessLogTests(sprockets.http.testing.SprocketsHttpTestCase): with self.assertLogs(log.access_log, log_level) as context: self.app.log_request(handler) self.assertEqual(context.records[0].levelno, log_level) + + +class ServerHeaderTests(sprockets.http.testing.SprocketsHttpTestCase): + + def get_app(self): + self.app = sprockets.http.app.Application( + server_header='a/b/c') + return self.app + + def test_reads_from_settings(self): + self.app.settings['server_header'] = 'a/b/c' + response = self.fetch('/') + self.assertEqual('a/b/c', response.headers['Server']) + + self.app.settings['server_header'] = 'some server' + response = self.fetch('/') + self.assertEqual('some server', response.headers['Server']) + + self.app.settings['server_header'] = None + response = self.fetch('/') + self.assertNotIn('Server', response.headers) + + self.app.settings.pop('server_header') + response = self.fetch('/') + self.assertNotIn('Server', response.headers) + + def test_defaults(self): + app = sprockets.http.app.Application() + self.assertIsNone(app.settings['server_header']) + + app = sprockets.http.app.Application(service='myservice', + version='myversion') + self.assertEqual('myservice/myversion', app.settings['server_header']) + + app = sprockets.http.app.Application(service='myservice', + version=None) + self.assertEqual('myservice', app.settings['server_header']) + + app = sprockets.http.app.Application(service=None, + version='myversion') + self.assertIsNone(app.settings['server_header'])