mirror of
https://github.com/valitydev/redash.git
synced 2024-11-07 01:25:16 +00:00
Use a custom EmailType/CaseInsensitiveComparator
sqlalchemy doesn't appear to have a case sensitive column, so we craft a type that can have the semantics of case insensitive comparison. Ref http://docs.sqlalchemy.org/en/rel_1_1/orm/extensions/hybrid.html#building-custom-comparators Signed-off-by: Mike Fiedler <miketheman@gmail.com>
This commit is contained in:
parent
9076715d84
commit
be232543d6
@ -112,7 +112,8 @@ def delete(email, organization=None):
|
|||||||
models.User.org == org.id,
|
models.User.org == org.id,
|
||||||
).delete()
|
).delete()
|
||||||
else:
|
else:
|
||||||
deleted_count = models.User.query.filter(models.User.email == email).delete()
|
deleted_count = models.User.query.filter(models.User.email == email).delete(
|
||||||
|
synchronize_session=False)
|
||||||
models.db.session.commit()
|
models.db.session.commit()
|
||||||
print("Deleted %d users." % deleted_count)
|
print("Deleted %d users." % deleted_count)
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ def forgot_password(org_slug=None):
|
|||||||
submitted = False
|
submitted = False
|
||||||
if request.method == 'POST' and request.form['email']:
|
if request.method == 'POST' and request.form['email']:
|
||||||
submitted = True
|
submitted = True
|
||||||
email = request.form['email'].lower()
|
email = request.form['email']
|
||||||
try:
|
try:
|
||||||
org = current_org._get_current_object()
|
org = current_org._get_current_object()
|
||||||
user = models.User.get_by_email_and_org(email, org)
|
user = models.User.get_by_email_and_org(email, org)
|
||||||
@ -119,7 +119,7 @@ def login(org_slug=None):
|
|||||||
if request.method == 'POST':
|
if request.method == 'POST':
|
||||||
try:
|
try:
|
||||||
org = current_org._get_current_object()
|
org = current_org._get_current_object()
|
||||||
user = models.User.get_by_email_and_org(request.form['email'].lower(), org)
|
user = models.User.get_by_email_and_org(request.form['email'], org)
|
||||||
if user and user.verify_password(request.form['password']):
|
if user and user.verify_password(request.form['password']):
|
||||||
remember = ('remember' in request.form)
|
remember = ('remember' in request.form)
|
||||||
login_user(user, remember=remember)
|
login_user(user, remember=remember)
|
||||||
|
@ -30,7 +30,7 @@ class UserListResource(BaseResource):
|
|||||||
|
|
||||||
user = models.User(org=self.current_org,
|
user = models.User(org=self.current_org,
|
||||||
name=req['name'],
|
name=req['name'],
|
||||||
email=req['email'].lower(),
|
email=req['email'],
|
||||||
group_ids=[self.current_org.default_group.id])
|
group_ids=[self.current_org.default_group.id])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -22,6 +22,7 @@ from redash.permissions import has_access, view_only
|
|||||||
from redash.query_runner import (get_configuration_schema_for_query_runner_type,
|
from redash.query_runner import (get_configuration_schema_for_query_runner_type,
|
||||||
get_query_runner)
|
get_query_runner)
|
||||||
from redash.utils import generate_token, json_dumps
|
from redash.utils import generate_token, json_dumps
|
||||||
|
from redash.utils.comparators import CaseInsensitiveComparator
|
||||||
from redash.utils.configuration import ConfigurationContainer
|
from redash.utils.configuration import ConfigurationContainer
|
||||||
from sqlalchemy import distinct, or_
|
from sqlalchemy import distinct, or_
|
||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
@ -348,12 +349,32 @@ class Group(db.Model, BelongsToOrgMixin):
|
|||||||
return unicode(self.id)
|
return unicode(self.id)
|
||||||
|
|
||||||
|
|
||||||
|
class EmailType(TypeDecorator):
|
||||||
|
"""
|
||||||
|
A case-insensitive string
|
||||||
|
"""
|
||||||
|
impl = db.String
|
||||||
|
comparator_factory = CaseInsensitiveComparator
|
||||||
|
|
||||||
|
def __init__(self, length=320, *args, **kwargs):
|
||||||
|
super(EmailType, self).__init__(length=length, *args, **kwargs)
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
if value is not None:
|
||||||
|
return value.lower()
|
||||||
|
return value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def python_type(self):
|
||||||
|
return self.impl.type.python_type
|
||||||
|
|
||||||
|
|
||||||
class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin):
|
class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin):
|
||||||
id = Column(db.Integer, primary_key=True)
|
id = Column(db.Integer, primary_key=True)
|
||||||
org_id = Column(db.Integer, db.ForeignKey('organizations.id'))
|
org_id = Column(db.Integer, db.ForeignKey('organizations.id'))
|
||||||
org = db.relationship(Organization, backref=db.backref("users", lazy="dynamic"))
|
org = db.relationship(Organization, backref=db.backref("users", lazy="dynamic"))
|
||||||
name = Column(db.String(320))
|
name = Column(db.String(320))
|
||||||
email = Column(db.String(320))
|
email = Column(EmailType)
|
||||||
password_hash = Column(db.String(128), nullable=True)
|
password_hash = Column(db.String(128), nullable=True)
|
||||||
# XXX replace with association table
|
# XXX replace with association table
|
||||||
group_ids = Column('groups', MutableList.as_mutable(postgresql.ARRAY(db.Integer)), nullable=True)
|
group_ids = Column('groups', MutableList.as_mutable(postgresql.ARRAY(db.Integer)), nullable=True)
|
||||||
@ -365,6 +386,7 @@ class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCh
|
|||||||
__table_args__ = (db.Index('users_org_id_email', 'org_id', 'email', unique=True),)
|
__table_args__ = (db.Index('users_org_id_email', 'org_id', 'email', unique=True),)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
kwargs["email"] = kwargs.get("email", None).lower()
|
||||||
super(User, self).__init__(*args, **kwargs)
|
super(User, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def to_dict(self, with_api_key=False):
|
def to_dict(self, with_api_key=False):
|
||||||
|
7
redash/utils/comparators.py
Normal file
7
redash/utils/comparators.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from sqlalchemy import func
|
||||||
|
from sqlalchemy.ext.hybrid import Comparator
|
||||||
|
|
||||||
|
|
||||||
|
class CaseInsensitiveComparator(Comparator):
|
||||||
|
def __eq__(self, other):
|
||||||
|
return func.lower(self.__clause_element__()) == func.lower(other)
|
@ -25,3 +25,23 @@ class TestUserFindByEmail(BaseTestCase):
|
|||||||
users = User.find_by_email(user.email)
|
users = User.find_by_email(user.email)
|
||||||
self.assertIn(user, users)
|
self.assertIn(user, users)
|
||||||
self.assertIn(user2, users)
|
self.assertIn(user2, users)
|
||||||
|
|
||||||
|
def test_finds_users_case_insensitive(self):
|
||||||
|
user = self.factory.create_user(email='test@example.com')
|
||||||
|
|
||||||
|
users = User.find_by_email('test@EXAMPLE.com')
|
||||||
|
self.assertIn(user, users)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserGetByEmailAndOrg(BaseTestCase):
|
||||||
|
def test_get_user_by_email_and_org(self):
|
||||||
|
user = self.factory.create_user(email='test@example.com')
|
||||||
|
|
||||||
|
found_user = User.get_by_email_and_org(user.email, user.org)
|
||||||
|
self.assertEqual(user, found_user)
|
||||||
|
|
||||||
|
def test_get_user_by_email_and_org_case_insensitive(self):
|
||||||
|
user = self.factory.create_user(email='test@example.com')
|
||||||
|
|
||||||
|
found_user = User.get_by_email_and_org("TEST@example.com", user.org)
|
||||||
|
self.assertEqual(user, found_user)
|
||||||
|
Loading…
Reference in New Issue
Block a user