Use a custom EmailType/CaseInsensitiveComparator

sqlalchemy doesn't appear to have a case sensitive column, so we craft a
type that can have the semantics of case insensitive comparison.

Ref http://docs.sqlalchemy.org/en/rel_1_1/orm/extensions/hybrid.html#building-custom-comparators

Signed-off-by: Mike Fiedler <miketheman@gmail.com>
This commit is contained in:
Mike Fiedler 2017-10-10 17:35:25 -04:00
parent 9076715d84
commit be232543d6
No known key found for this signature in database
GPG Key ID: 5E1134F2FAF158B9
6 changed files with 55 additions and 5 deletions

View File

@ -112,7 +112,8 @@ def delete(email, organization=None):
models.User.org == org.id, models.User.org == org.id,
).delete() ).delete()
else: else:
deleted_count = models.User.query.filter(models.User.email == email).delete() deleted_count = models.User.query.filter(models.User.email == email).delete(
synchronize_session=False)
models.db.session.commit() models.db.session.commit()
print("Deleted %d users." % deleted_count) print("Deleted %d users." % deleted_count)

View File

@ -80,7 +80,7 @@ def forgot_password(org_slug=None):
submitted = False submitted = False
if request.method == 'POST' and request.form['email']: if request.method == 'POST' and request.form['email']:
submitted = True submitted = True
email = request.form['email'].lower() email = request.form['email']
try: try:
org = current_org._get_current_object() org = current_org._get_current_object()
user = models.User.get_by_email_and_org(email, org) user = models.User.get_by_email_and_org(email, org)
@ -119,7 +119,7 @@ def login(org_slug=None):
if request.method == 'POST': if request.method == 'POST':
try: try:
org = current_org._get_current_object() org = current_org._get_current_object()
user = models.User.get_by_email_and_org(request.form['email'].lower(), org) user = models.User.get_by_email_and_org(request.form['email'], org)
if user and user.verify_password(request.form['password']): if user and user.verify_password(request.form['password']):
remember = ('remember' in request.form) remember = ('remember' in request.form)
login_user(user, remember=remember) login_user(user, remember=remember)

View File

@ -30,7 +30,7 @@ class UserListResource(BaseResource):
user = models.User(org=self.current_org, user = models.User(org=self.current_org,
name=req['name'], name=req['name'],
email=req['email'].lower(), email=req['email'],
group_ids=[self.current_org.default_group.id]) group_ids=[self.current_org.default_group.id])
try: try:

View File

@ -22,6 +22,7 @@ from redash.permissions import has_access, view_only
from redash.query_runner import (get_configuration_schema_for_query_runner_type, from redash.query_runner import (get_configuration_schema_for_query_runner_type,
get_query_runner) get_query_runner)
from redash.utils import generate_token, json_dumps from redash.utils import generate_token, json_dumps
from redash.utils.comparators import CaseInsensitiveComparator
from redash.utils.configuration import ConfigurationContainer from redash.utils.configuration import ConfigurationContainer
from sqlalchemy import distinct, or_ from sqlalchemy import distinct, or_
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
@ -348,12 +349,32 @@ class Group(db.Model, BelongsToOrgMixin):
return unicode(self.id) return unicode(self.id)
class EmailType(TypeDecorator):
"""
A case-insensitive string
"""
impl = db.String
comparator_factory = CaseInsensitiveComparator
def __init__(self, length=320, *args, **kwargs):
super(EmailType, self).__init__(length=length, *args, **kwargs)
def process_bind_param(self, value, dialect):
if value is not None:
return value.lower()
return value
@property
def python_type(self):
return self.impl.type.python_type
class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin): class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin):
id = Column(db.Integer, primary_key=True) id = Column(db.Integer, primary_key=True)
org_id = Column(db.Integer, db.ForeignKey('organizations.id')) org_id = Column(db.Integer, db.ForeignKey('organizations.id'))
org = db.relationship(Organization, backref=db.backref("users", lazy="dynamic")) org = db.relationship(Organization, backref=db.backref("users", lazy="dynamic"))
name = Column(db.String(320)) name = Column(db.String(320))
email = Column(db.String(320)) email = Column(EmailType)
password_hash = Column(db.String(128), nullable=True) password_hash = Column(db.String(128), nullable=True)
# XXX replace with association table # XXX replace with association table
group_ids = Column('groups', MutableList.as_mutable(postgresql.ARRAY(db.Integer)), nullable=True) group_ids = Column('groups', MutableList.as_mutable(postgresql.ARRAY(db.Integer)), nullable=True)
@ -365,6 +386,7 @@ class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCh
__table_args__ = (db.Index('users_org_id_email', 'org_id', 'email', unique=True),) __table_args__ = (db.Index('users_org_id_email', 'org_id', 'email', unique=True),)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs["email"] = kwargs.get("email", None).lower()
super(User, self).__init__(*args, **kwargs) super(User, self).__init__(*args, **kwargs)
def to_dict(self, with_api_key=False): def to_dict(self, with_api_key=False):

View File

@ -0,0 +1,7 @@
from sqlalchemy import func
from sqlalchemy.ext.hybrid import Comparator
class CaseInsensitiveComparator(Comparator):
def __eq__(self, other):
return func.lower(self.__clause_element__()) == func.lower(other)

View File

@ -25,3 +25,23 @@ class TestUserFindByEmail(BaseTestCase):
users = User.find_by_email(user.email) users = User.find_by_email(user.email)
self.assertIn(user, users) self.assertIn(user, users)
self.assertIn(user2, users) self.assertIn(user2, users)
def test_finds_users_case_insensitive(self):
user = self.factory.create_user(email='test@example.com')
users = User.find_by_email('test@EXAMPLE.com')
self.assertIn(user, users)
class TestUserGetByEmailAndOrg(BaseTestCase):
def test_get_user_by_email_and_org(self):
user = self.factory.create_user(email='test@example.com')
found_user = User.get_by_email_and_org(user.email, user.org)
self.assertEqual(user, found_user)
def test_get_user_by_email_and_org_case_insensitive(self):
user = self.factory.create_user(email='test@example.com')
found_user = User.get_by_email_and_org("TEST@example.com", user.org)
self.assertEqual(user, found_user)