diff --git a/redash/authentication/remote_user_auth.py b/redash/authentication/remote_user_auth.py index f1b72b11..027da027 100644 --- a/redash/authentication/remote_user_auth.py +++ b/redash/authentication/remote_user_auth.py @@ -2,19 +2,20 @@ import logging from flask import redirect, url_for, Blueprint, request from redash.authentication.google_oauth import create_and_login_user from redash.authentication.org_resolving import current_org +from redash.handlers.base import org_scoped_rule from redash import settings logger = logging.getLogger('remote_user_auth') blueprint = Blueprint('remote_user_auth', __name__) -@blueprint.route("/remote_user/login") -def login(): +@blueprint.route(org_scoped_rule("/remote_user/login")) +def login(org_slug=None): next_path = request.args.get('next') if not settings.REMOTE_USER_LOGIN_ENABLED: logger.error("Cannot use remote user for login without being enabled in settings") - return redirect(url_for('redash.index', next=next_path)) + return redirect(url_for('redash.index', next=next_path, org_slug=org_slug)) email = request.headers.get(settings.REMOTE_USER_HEADER) @@ -27,8 +28,8 @@ def login(): if not email: logger.error("Cannot use remote user for login when it's not provided in the request (looked in headers['" + settings.REMOTE_USER_HEADER + "'])") - return redirect(url_for('redash.index', next=next_path)) + return redirect(url_for('redash.index', next=next_path, org_slug=org_slug)) logger.info("Logging in " + email + " via remote user") create_and_login_user(current_org, email, email) - return redirect(next_path or url_for('redash.index'), code=302) + return redirect(next_path or url_for('redash.index', org_slug=org_slug), code=302) diff --git a/tests/__init__.py b/tests/__init__.py index be66c5d9..5e5b486f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -91,17 +91,17 @@ class BaseTestCase(TestCase): return response - def get_request(self, path, org=None): + def get_request(self, path, org=None, headers=None): if org: path = "/{}{}".format(org.slug, path) - return self.client.get(path) + return self.client.get(path, headers=headers) - def post_request(self, path, data=None, org=None): + def post_request(self, path, data=None, org=None, headers=None): if org: path = "/{}{}".format(org.slug, path) - return self.client.post(path, data=data) + return self.client.post(path, data=data, headers=headers) def assertResponseEqual(self, expected, actual): for k, v in expected.iteritems(): diff --git a/tests/test_authentication.py b/tests/test_authentication.py index f527e94c..3b5fa1c1 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,7 +1,9 @@ +import os import time from flask import request from mock import patch +from sqlalchemy.orm.exc import NoResultFound from tests import BaseTestCase from redash import models, settings @@ -181,3 +183,88 @@ class TestGetLoginUrl(BaseTestCase): def test_when_multi_org_enabled_and_org_doesnt_exist(self): with self.app.test_request_context('/{}_notexists/'.format(self.factory.org.slug)): self.assertEqual(get_login_url(next=None), '/') + +class TestRemoteUserAuth(BaseTestCase): + DEFAULT_SETTING_OVERRIDES = { + 'REDASH_REMOTE_USER_LOGIN_ENABLED': 'true' + } + + def setUp(self): + # Apply default setting overrides to every test + self.override_settings(None) + + super(TestRemoteUserAuth, self).setUp() + + def override_settings(self, overrides): + """Override settings for testing purposes. + + This helper method can be used to override specific environmental + variables to enable / disable Re:Dash features for the duration + of the test. + + Note that these overrides only affect code that checks the value of + the setting at runtime. It doesn't affect code that only checks the + value during program initialization. + + :param dict overrides: a dict of environmental variables to override + when the settings are reloaded + """ + variables = self.DEFAULT_SETTING_OVERRIDES.copy() + variables.update(overrides or {}) + with patch.dict(os.environ, variables): + reload(settings) + + # Queue a cleanup routine that reloads the settings without overrides + # once the test ends + self.addCleanup(lambda: reload(settings)) + + def assert_correct_user_attributes(self, user, email='test@example.com', name='test@example.com', groups=None, org=None): + """Helper to assert that the user attributes are correct.""" + groups = groups or [] + if self.factory.org.default_group.id not in groups: + groups.append(self.factory.org.default_group.id) + + self.assertIsNotNone(user) + self.assertEqual(user.email, email) + self.assertEqual(user.name, name) + self.assertEqual(user.org, org or self.factory.org) + self.assertItemsEqual(user.group_ids, groups) + + def get_test_user(self, email='test@example.com', org=None): + """Helper to fetch an user from the database.""" + + # Expire all cached objects to ensure these values are read directly + # from the database. + models.db.session.expire_all() + + return models.User.get_by_email_and_org(email, org or self.factory.org) + + def test_remote_login_disabled(self): + self.override_settings({ + 'REDASH_REMOTE_USER_LOGIN_ENABLED': 'false' + }) + + self.get_request('/remote_user/login', org=self.factory.org, headers={ + 'X-Forwarded-Remote-User': 'test@example.com' + }) + + with self.assertRaises(NoResultFound): + self.get_test_user() + + def test_remote_login_default_header(self): + self.get_request('/remote_user/login', org=self.factory.org, headers={ + 'X-Forwarded-Remote-User': 'test@example.com' + }) + + self.assert_correct_user_attributes(self.get_test_user()) + + def test_remote_login_custom_header(self): + self.override_settings({ + 'REDASH_REMOTE_USER_HEADER': 'X-Custom-User' + }) + + self.get_request('/remote_user/login', org=self.factory.org, headers={ + 'X-Custom-User': 'test@example.com' + }) + + self.assert_correct_user_attributes(self.get_test_user())