diff --git a/flask_security/core.py b/flask_security/core.py index 80d3e1db..f410f15b 100644 --- a/flask_security/core.py +++ b/flask_security/core.py @@ -25,8 +25,10 @@ from werkzeug.datastructures import ImmutableList from werkzeug.local import LocalProxy -from .forms import ChangePasswordForm, ConfirmRegisterForm, \ - ForgotPasswordForm, LoginForm, PasswordlessLoginForm, RegisterForm, \ +from .forms import ChangePasswordForm, \ + EmailConfirmRegisterForm, UsernameConfirmRegisterForm, \ + ForgotPasswordForm, EmailLoginForm, UsernameLoginForm, \ + PasswordlessLoginForm, EmailRegisterForm, UsernameRegisterForm, \ ResetPasswordForm, SendConfirmationForm from .utils import config_value as cv from .utils import _, get_config, hash_data, localize_callback, string_types, \ @@ -142,6 +144,8 @@ _('Invalid confirmation token.'), 'error'), 'EMAIL_ALREADY_ASSOCIATED': ( _('%(email)s is already associated with an account.'), 'error'), + 'USERNAME_ALREADY_IN_USE': ( + _('%(username)s is already in use.'), 'error'), 'PASSWORD_MISMATCH': ( _('Password does not match'), 'error'), 'RETYPE_PASSWORD_MISMATCH': ( @@ -177,6 +181,10 @@ _('Email not provided'), 'error'), 'INVALID_EMAIL_ADDRESS': ( _('Invalid email address'), 'error'), + 'USERNAME_NOT_PROVIDED': ( + _('Username not provided'), 'error'), + 'INVALID_USERNAME': ( + _('Invalid username'), 'error'), 'PASSWORD_NOT_PROVIDED': ( _('Password not provided'), 'error'), 'PASSWORD_NOT_SET': ( @@ -205,10 +213,21 @@ _('Please reauthenticate to access this page.'), 'info'), } -_default_forms = { - 'login_form': LoginForm, - 'confirm_register_form': ConfirmRegisterForm, - 'register_form': RegisterForm, +_default_email_forms = { + 'login_form': EmailLoginForm, + 'confirm_register_form': EmailConfirmRegisterForm, + 'register_form': EmailRegisterForm, + 'forgot_password_form': ForgotPasswordForm, + 'reset_password_form': ResetPasswordForm, + 'change_password_form': ChangePasswordForm, + 'send_confirmation_form': SendConfirmationForm, + 'passwordless_login_form': PasswordlessLoginForm, +} + +_default_username_forms = { + 'login_form': UsernameLoginForm, + 'confirm_register_form': UsernameConfirmRegisterForm, + 'register_form': UsernameRegisterForm, 'forgot_password_form': ForgotPasswordForm, 'reset_password_form': ResetPasswordForm, 'change_password_form': ChangePasswordForm, @@ -342,7 +361,16 @@ def _get_state(app, datastore, anonymous_user=None, **kwargs): _unauthorized_callback=None )) - for key, value in _default_forms.items(): + ident_attrs = app.config.get( + "SECURITY_USER_IDENTITY_ATTRIBUTES", + ["email"], + ) + if ident_attrs == ["email"]: + default_forms = _default_email_forms + else: + default_forms = _default_username_forms + + for key, value in default_forms.items(): if key not in kwargs or not kwargs[key]: kwargs[key] = value diff --git a/flask_security/datastore.py b/flask_security/datastore.py index c8109eb8..06d31939 100644 --- a/flask_security/datastore.py +++ b/flask_security/datastore.py @@ -118,7 +118,8 @@ def __init__(self, user_model, role_model): def _prepare_role_modify_args(self, user, role): if isinstance(user, string_types): - user = self.find_user(email=user) + user_kwargs = {attr: user for attr in get_identity_attributes()} + user = self.find_user(**user_kwargs) if isinstance(role, string_types): role = self.find_role(role) return user, role diff --git a/flask_security/forms.py b/flask_security/forms.py index 903b3636..c7cc9f73 100644 --- a/flask_security/forms.py +++ b/flask_security/forms.py @@ -28,6 +28,7 @@ _default_field_labels = { 'email': _('Email Address'), + 'username': _('Username'), 'password': _('Password'), 'remember_me': _('Remember Me'), 'login': _('Login'), @@ -61,12 +62,18 @@ class Email(ValidatorMixin, validators.Email): pass +class Regexp(ValidatorMixin, validators.Regexp): + pass + + class Length(ValidatorMixin, validators.Length): pass email_required = Required(message='EMAIL_NOT_PROVIDED') email_validator = Email(message='INVALID_EMAIL_ADDRESS') +username_required = Required(message='USERNAME_NOT_PROVIDED') +username_validator = Regexp(r"[A-Za-z0-9_]+", message='INVALID_USERNAME') password_required = Required(message='PASSWORD_NOT_PROVIDED') password_length = Length(min=6, max=128, message='PASSWORD_INVALID_LENGTH') @@ -81,12 +88,21 @@ def unique_user_email(form, field): raise ValidationError(msg) +def unique_user_username(form, field): + if _datastore.get_user(field.data) is not None: + msg = get_message('USERNAME_ALREADY_IN_USE', username=field.data)[0] + raise ValidationError(msg) + + def valid_user_email(form, field): form.user = _datastore.get_user(field.data) if form.user is None: raise ValidationError(get_message('USER_DOES_NOT_EXIST')[0]) +valid_user_username = valid_user_email + + class Form(BaseForm): def __init__(self, *args, **kwargs): if current_app.testing: @@ -94,6 +110,12 @@ def __init__(self, *args, **kwargs): super(Form, self).__init__(*args, **kwargs) +class IdentifierForm(Form): + def __init__(self, *args, **kwargs): + super(IdentifierForm, self).__init__(*args, **kwargs) + setattr(self, "identifier", getattr(self, self.identifier_field)) + + class EmailFormMixin(): email = StringField( get_form_field_label('email'), @@ -105,12 +127,41 @@ class UserEmailFormMixin(): email = StringField( get_form_field_label('email'), validators=[email_required, email_validator, valid_user_email]) + identifier_field = "email" class UniqueEmailFormMixin(): email = StringField( get_form_field_label('email'), validators=[email_required, email_validator, unique_user_email]) + identifier_field = "email" + + +class UsernameFormMixin(): + username = StringField( + get_form_field_label('username'), + validators=[username_required, username_validator]) + + +class UserUsernameFormMixin(): + user = None + username = StringField( + get_form_field_label('username'), + validators=[ + username_required, username_validator, valid_user_username + ] + ) + identifier_field = "username" + + +class UniqueUsernameFormMixin(): + username = StringField( + get_form_field_label('username'), + validators=[ + username_required, username_validator, unique_user_username + ] + ) + identifier_field = "username" class PasswordFormMixin(): @@ -150,11 +201,12 @@ def is_field_and_user_attr(member): hasattr(_datastore.user_model, member.name) fields = inspect.getmembers(form, is_field_and_user_attr) - return dict((key, value.data) for key, value in fields) + return dict((key, value.data) for key, value in fields + if key != "identifier") class SendConfirmationForm(Form, UserEmailFormMixin): - """The default forgot password form""" + """The default send confirmation form""" submit = SubmitField(get_form_field_label('send_confirmation')) @@ -172,20 +224,34 @@ def validate(self): return True -class ForgotPasswordForm(Form, UserEmailFormMixin): +class AbstractForgotPasswordForm(IdentifierForm): """The default forgot password form""" submit = SubmitField(get_form_field_label('recover_password')) def validate(self): - if not super(ForgotPasswordForm, self).validate(): + if not super(AbstractForgotPasswordForm, self).validate(): return False if requires_confirmation(self.user): - self.email.errors.append(get_message('CONFIRMATION_REQUIRED')[0]) + self.identifier.errors.append( + get_message('CONFIRMATION_REQUIRED')[0] + ) return False return True +class EmailForgotPasswordForm(AbstractForgotPasswordForm, UserEmailFormMixin): + pass + + +class UsernameForgotPasswordForm(AbstractForgotPasswordForm, + UserUsernameFormMixin): + pass + + +ForgotPasswordForm = EmailForgotPasswordForm + + class PasswordlessLoginForm(Form, UserEmailFormMixin): """The passwordless login form""" @@ -203,18 +269,16 @@ def validate(self): return True -class LoginForm(Form, NextFormMixin): +class AbstractLoginForm(IdentifierForm, NextFormMixin): """The default login form""" - email = StringField(get_form_field_label('email'), - validators=[Required(message='EMAIL_NOT_PROVIDED')]) password = PasswordField(get_form_field_label('password'), validators=[password_required]) remember = BooleanField(get_form_field_label('remember_me')) submit = SubmitField(get_form_field_label('login')) def __init__(self, *args, **kwargs): - super(LoginForm, self).__init__(*args, **kwargs) + super(AbstractLoginForm, self).__init__(*args, **kwargs) if not self.next.data: self.next.data = request.args.get('next', '') self.remember.default = config_value('DEFAULT_REMEMBER_ME') @@ -227,42 +291,77 @@ def __init__(self, *args, **kwargs): self.password.description = html def validate(self): - if not super(LoginForm, self).validate(): + if not super(AbstractLoginForm, self).validate(): return False - self.user = _datastore.get_user(self.email.data) + self.user = _datastore.get_user(self.identifier.data) if self.user is None: - self.email.errors.append(get_message('USER_DOES_NOT_EXIST')[0]) + self.identifier.errors.append( + get_message('USER_DOES_NOT_EXIST')[0] + ) return False if not self.user.password: - self.password.errors.append(get_message('PASSWORD_NOT_SET')[0]) + self.identifier.errors.append(get_message('PASSWORD_NOT_SET')[0]) return False if not verify_and_update_password(self.password.data, self.user): self.password.errors.append(get_message('INVALID_PASSWORD')[0]) return False if requires_confirmation(self.user): - self.email.errors.append(get_message('CONFIRMATION_REQUIRED')[0]) + self.identifier.errors.append( + get_message('CONFIRMATION_REQUIRED')[0] + ) return False if not self.user.is_active: - self.email.errors.append(get_message('DISABLED_ACCOUNT')[0]) + self.identifier.errors.append(get_message('DISABLED_ACCOUNT')[0]) return False return True -class ConfirmRegisterForm(Form, RegisterFormMixin, - UniqueEmailFormMixin, NewPasswordFormMixin): +class EmailLoginForm(AbstractLoginForm, UserEmailFormMixin): + pass + + +class UsernameLoginForm(AbstractLoginForm, UserUsernameFormMixin): + pass + + +LoginForm = EmailLoginForm + + +class EmailConfirmRegisterForm(IdentifierForm, RegisterFormMixin, + UniqueEmailFormMixin, NewPasswordFormMixin): pass -class RegisterForm(ConfirmRegisterForm, PasswordConfirmFormMixin, - NextFormMixin): +class UsernameConfirmRegisterForm(IdentifierForm, RegisterFormMixin, + UniqueUsernameFormMixin, + NewPasswordFormMixin): + pass + + +ConfirmRegisterForm = EmailConfirmRegisterForm + + +class EmailRegisterForm(EmailConfirmRegisterForm, PasswordConfirmFormMixin, + NextFormMixin): + def __init__(self, *args, **kwargs): + super(EmailRegisterForm, self).__init__(*args, **kwargs) + if not self.next.data: + self.next.data = request.args.get('next', '') + + +class UsernameRegisterForm(UsernameConfirmRegisterForm, + PasswordConfirmFormMixin, NextFormMixin): def __init__(self, *args, **kwargs): - super(RegisterForm, self).__init__(*args, **kwargs) + super(UsernameRegisterForm, self).__init__(*args, **kwargs) if not self.next.data: self.next.data = request.args.get('next', '') +RegisterForm = EmailRegisterForm + + class ResetPasswordForm(Form, NewPasswordFormMixin, PasswordConfirmFormMixin): """The default reset password form""" diff --git a/flask_security/templates/security/login_user.html b/flask_security/templates/security/login_user.html index 0962d2f5..2d5d513b 100644 --- a/flask_security/templates/security/login_user.html +++ b/flask_security/templates/security/login_user.html @@ -3,7 +3,7 @@

{{ _('Login') }}

{{ login_user_form.hidden_tag() }} - {{ render_field_with_errors(login_user_form.email) }} + {{ render_field_with_errors(login_user_form.identifier) }} {{ render_field_with_errors(login_user_form.password) }} {{ render_field_with_errors(login_user_form.remember) }} {{ render_field(login_user_form.next) }} diff --git a/flask_security/templates/security/register_user.html b/flask_security/templates/security/register_user.html index 0380c742..6f7af408 100644 --- a/flask_security/templates/security/register_user.html +++ b/flask_security/templates/security/register_user.html @@ -3,7 +3,7 @@

{{ _('Register') }}

{{ register_user_form.hidden_tag() }} - {{ render_field_with_errors(register_user_form.email) }} + {{ render_field_with_errors(register_user_form.identifier) }} {{ render_field_with_errors(register_user_form.password) }} {% if register_user_form.password_confirm %} {{ render_field_with_errors(register_user_form.password_confirm) }} diff --git a/tests/conftest.py b/tests/conftest.py index 0d263f86..5b0ab0dd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,8 +37,7 @@ def default(self, o): return BaseEncoder.default(self, o) -@pytest.fixture() -def app(request): +def create_fixture_app(keywords, identity_attrs): app = Flask(__name__) app.response_class = Response app.debug = True @@ -47,20 +46,20 @@ def app(request): app.config['LOGIN_DISABLED'] = False app.config['WTF_CSRF_ENABLED'] = False app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False + app.config["SECURITY_USER_IDENTITY_ATTRIBUTES"] = identity_attrs app.config['SECURITY_PASSWORD_SALT'] = 'salty' for opt in ['changeable', 'recoverable', 'registerable', 'trackable', 'passwordless', 'confirmable']: - app.config['SECURITY_' + opt.upper()] = opt in request.keywords + app.config['SECURITY_' + opt.upper()] = opt in keywords - if 'settings' in request.keywords: - for key, value in request.keywords['settings'].kwargs.items(): + if 'settings' in keywords: + for key, value in keywords['settings'].kwargs.items(): app.config['SECURITY_' + key.upper()] = value mail = Mail(app) - if 'babel' not in request.keywords or \ - request.keywords['babel'].args[0]: + if 'babel' not in keywords or keywords['babel'].args[0]: babel = Babel(app) app.babel = babel app.json_encoder = JSONEncoder @@ -135,19 +134,35 @@ def page_1(): return app +@pytest.fixture() +def username_app(request): + return create_fixture_app(request.keywords, ["username"]) + + +@pytest.fixture() +def email_app(request): + return create_fixture_app(request.keywords, ["email"]) + + +@pytest.fixture() +def app(email_app): # , username_app): + # TODO: this fixture should be parametrized, like the `datastore` fixture + return email_app + + @pytest.yield_fixture() -def mongoengine_datastore(app): +def mongoengine_email_datastore(email_app): from flask_mongoengine import MongoEngine db_name = 'flask_security_test_%s' % str(time.time()).replace('.', '_') - app.config['MONGODB_SETTINGS'] = { + email_app.config['MONGODB_SETTINGS'] = { 'db': db_name, 'host': 'localhost', 'port': 27017, 'alias': db_name } - db = MongoEngine(app) + db = MongoEngine(email_app) class Role(db.Document, RoleMixin): name = db.StringField(required=True, unique=True, max_length=80) @@ -156,7 +171,7 @@ class Role(db.Document, RoleMixin): class User(db.Document, UserMixin): email = db.StringField(unique=True, max_length=255) - username = db.StringField(max_length=255) + username = db.StringField(max_length=255) # TODO: remove password = db.StringField(required=False, max_length=255) last_login_at = db.DateTimeField() current_login_at = db.DateTimeField() @@ -170,19 +185,67 @@ class User(db.Document, UserMixin): yield MongoEngineUserDatastore(db, User, Role) - with app.app_context(): + with email_app.app_context(): db.connection.drop_database(db_name) +@pytest.yield_fixture() +def mongoengine_username_datastore(username_app): + from flask_mongoengine import MongoEngine + + db_name = 'flask_security_test_%s' % str(time.time()).replace('.', '_') + username_app.config['MONGODB_SETTINGS'] = { + 'db': db_name, + 'host': 'localhost', + 'port': 27017, + 'alias': db_name + } + + db = MongoEngine(username_app) + + class Role(db.Document, RoleMixin): + name = db.StringField(required=True, unique=True, max_length=80) + description = db.StringField(max_length=255) + meta = {"db_alias": db_name} + + class User(db.Document, UserMixin): + username = db.StringField(unique=True, max_length=255) + password = db.StringField(required=False, max_length=255) + last_login_at = db.DateTimeField() + current_login_at = db.DateTimeField() + last_login_ip = db.StringField(max_length=100) + current_login_ip = db.StringField(max_length=100) + login_count = db.IntField() + active = db.BooleanField(default=True) + confirmed_at = db.DateTimeField() + roles = db.ListField(db.ReferenceField(Role), default=[]) + meta = {"db_alias": db_name} + + yield MongoEngineUserDatastore(db, User, Role) + + with username_app.app_context(): + db.connection.drop_database(db_name) + + +@pytest.fixture() +def mongoengine_datastore(app, + mongoengine_email_datastore, + mongoengine_username_datastore): + if app.config["SECURITY_USER_IDENTITY_ATTRIBUTES"] == ["email"]: + return mongoengine_email_datastore + else: + return mongoengine_username_datastore + + @pytest.fixture() -def sqlalchemy_datastore(request, app, tmpdir): +def sqlalchemy_email_datastore(request, email_app, tmpdir): from flask_sqlalchemy import SQLAlchemy f, path = tempfile.mkstemp( prefix='flask-security-test-db', suffix='.db', dir=str(tmpdir)) - app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + path - db = SQLAlchemy(app) + email_app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + path + db = SQLAlchemy(email_app) roles_users = db.Table( 'roles_users', @@ -197,7 +260,7 @@ class Role(db.Model, RoleMixin): class User(db.Model, UserMixin): id = db.Column(db.Integer, primary_key=True) email = db.Column(db.String(255), unique=True) - username = db.Column(db.String(255)) + username = db.Column(db.String(255)) # TODO: remove password = db.Column(db.String(255)) last_login_at = db.Column(db.DateTime()) current_login_at = db.Column(db.DateTime()) @@ -209,7 +272,7 @@ class User(db.Model, UserMixin): roles = db.relationship('Role', secondary=roles_users, backref=db.backref('users', lazy='dynamic')) - with app.app_context(): + with email_app.app_context(): db.create_all() def tear_down(): @@ -221,7 +284,62 @@ def tear_down(): @pytest.fixture() -def sqlalchemy_session_datastore(request, app, tmpdir): +def sqlalchemy_username_datastore(request, username_app, tmpdir): + from flask_sqlalchemy import SQLAlchemy + + f, path = tempfile.mkstemp( + prefix='flask-security-test-db', suffix='.db', dir=str(tmpdir)) + + username_app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + path + db = SQLAlchemy(username_app) + + roles_users = db.Table( + 'roles_users', + db.Column('user_id', db.Integer(), db.ForeignKey('user.id')), + db.Column('role_id', db.Integer(), db.ForeignKey('role.id'))) + + class Role(db.Model, RoleMixin): + id = db.Column(db.Integer(), primary_key=True) + name = db.Column(db.String(80), unique=True) + description = db.Column(db.String(255)) + + class User(db.Model, UserMixin): + id = db.Column(db.Integer, primary_key=True) + username = db.Column(db.String(255), unique=True) + password = db.Column(db.String(255)) + last_login_at = db.Column(db.DateTime()) + current_login_at = db.Column(db.DateTime()) + last_login_ip = db.Column(db.String(100)) + current_login_ip = db.Column(db.String(100)) + login_count = db.Column(db.Integer) + active = db.Column(db.Boolean()) + confirmed_at = db.Column(db.DateTime()) + roles = db.relationship('Role', secondary=roles_users, + backref=db.backref('users', lazy='dynamic')) + + with username_app.app_context(): + db.create_all() + + def tear_down(): + os.close(f) + os.remove(path) + request.addfinalizer(tear_down) + + return SQLAlchemyUserDatastore(db, User, Role) + + +@pytest.fixture() +def sqlalchemy_datastore(app, + sqlalchemy_email_datastore, + sqlalchemy_username_datastore): + if app.config["SECURITY_USER_IDENTITY_ATTRIBUTES"] == ["email"]: + return sqlalchemy_email_datastore + else: + return sqlalchemy_username_datastore + + +@pytest.fixture() +def sqlalchemy_session_email_datastore(request, email_app, tmpdir): from sqlalchemy import create_engine from sqlalchemy.orm import scoped_session, sessionmaker, relationship, \ backref @@ -232,9 +350,9 @@ def sqlalchemy_session_datastore(request, app, tmpdir): f, path = tempfile.mkstemp( prefix='flask-security-test-db', suffix='.db', dir=str(tmpdir)) - app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + path + email_app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + path - engine = create_engine(app.config['SQLALCHEMY_DATABASE_URI'], + engine = create_engine(email_app.config['SQLALCHEMY_DATABASE_URI'], convert_unicode=True) db_session = scoped_session(sessionmaker(autocommit=False, autoflush=False, @@ -258,7 +376,7 @@ class User(Base, UserMixin): __tablename__ = 'user' id = Column(Integer, primary_key=True) email = Column(String(255), unique=True) - username = Column(String(255)) + username = Column(String(255)) # TODO: remove password = Column(String(255)) last_login_at = Column(DateTime()) current_login_at = Column(DateTime()) @@ -270,7 +388,68 @@ class User(Base, UserMixin): roles = relationship('Role', secondary='roles_users', backref=backref('users', lazy='dynamic')) - with app.app_context(): + with email_app.app_context(): + Base.metadata.create_all(bind=engine) + + def tear_down(): + db_session.close() + os.close(f) + os.remove(path) + request.addfinalizer(tear_down) + + return SQLAlchemySessionUserDatastore(db_session, User, Role) + + +@pytest.fixture() +def sqlalchemy_session_username_datastore(request, username_app, tmpdir): + from sqlalchemy import create_engine + from sqlalchemy.orm import scoped_session, sessionmaker, relationship, \ + backref + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy import Boolean, DateTime, Column, Integer, String, \ + ForeignKey + + f, path = tempfile.mkstemp( + prefix='flask-security-test-db', suffix='.db', dir=str(tmpdir)) + + username_app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + path + + engine = create_engine(username_app.config['SQLALCHEMY_DATABASE_URI'], + convert_unicode=True) + db_session = scoped_session(sessionmaker(autocommit=False, + autoflush=False, + bind=engine)) + Base = declarative_base() + Base.query = db_session.query_property() + + class RolesUsers(Base): + __tablename__ = 'roles_users' + id = Column(Integer(), primary_key=True) + user_id = Column('user_id', Integer(), ForeignKey('user.id')) + role_id = Column('role_id', Integer(), ForeignKey('role.id')) + + class Role(Base, RoleMixin): + __tablename__ = 'role' + id = Column(Integer(), primary_key=True) + name = Column(String(80), unique=True) + description = Column(String(255)) + + class User(Base, UserMixin): + __tablename__ = 'user' + id = Column(Integer, primary_key=True) + username = Column(String(255), unique=True) + password = Column(String(255)) + last_login_at = Column(DateTime()) + current_login_at = Column(DateTime()) + last_login_ip = Column(String(100)) + current_login_ip = Column(String(100)) + login_count = Column(Integer) + active = Column(Boolean()) + confirmed_at = Column(DateTime()) + roles = relationship('Role', secondary='roles_users', + backref=backref('users', lazy='dynamic')) + + with username_app.app_context(): Base.metadata.create_all(bind=engine) def tear_down(): @@ -283,7 +462,17 @@ def tear_down(): @pytest.fixture() -def peewee_datastore(request, app, tmpdir): +def sqlalchemy_session_datastore(app, + sqlalchemy_session_email_datastore, + sqlalchemy_session_username_datastore): + if app.config["SECURITY_USER_IDENTITY_ATTRIBUTES"] == ["email"]: + return sqlalchemy_session_email_datastore + else: + return sqlalchemy_session_username_datastore + + +@pytest.fixture() +def peewee_email_datastore(request, email_app, tmpdir): from peewee import TextField, DateTimeField, IntegerField, BooleanField, \ ForeignKeyField from flask_peewee.db import Database @@ -291,12 +480,12 @@ def peewee_datastore(request, app, tmpdir): f, path = tempfile.mkstemp( prefix='flask-security-test-db', suffix='.db', dir=str(tmpdir)) - app.config['DATABASE'] = { + email_app.config['DATABASE'] = { 'name': path, 'engine': 'peewee.SqliteDatabase' } - db = Database(app) + db = Database(email_app) class Role(db.Model, RoleMixin): name = TextField(unique=True) @@ -304,6 +493,59 @@ class Role(db.Model, RoleMixin): class User(db.Model, UserMixin): email = TextField() + username = TextField(null=True) # TODO: remove + password = TextField(null=True) + last_login_at = DateTimeField(null=True) + current_login_at = DateTimeField(null=True) + last_login_ip = TextField(null=True) + current_login_ip = TextField(null=True) + login_count = IntegerField(null=True) + active = BooleanField(default=True) + confirmed_at = DateTimeField(null=True) + + class UserRoles(db.Model): + """ Peewee does not have built-in many-to-many support, so we have to + create this mapping class to link users to roles.""" + user = ForeignKeyField(User, related_name='roles') + role = ForeignKeyField(Role, related_name='users') + name = property(lambda self: self.role.name) + description = property(lambda self: self.role.description) + + with email_app.app_context(): + for Model in (Role, User, UserRoles): + Model.create_table() + + def tear_down(): + db.close_db(None) + os.close(f) + os.remove(path) + + request.addfinalizer(tear_down) + + return PeeweeUserDatastore(db, User, Role, UserRoles) + + +@pytest.fixture() +def peewee_username_datastore(request, username_app, tmpdir): + from peewee import TextField, DateTimeField, IntegerField, BooleanField, \ + ForeignKeyField + from flask_peewee.db import Database + + f, path = tempfile.mkstemp( + prefix='flask-security-test-db', suffix='.db', dir=str(tmpdir)) + + username_app.config['DATABASE'] = { + 'name': path, + 'engine': 'peewee.SqliteDatabase' + } + + db = Database(username_app) + + class Role(db.Model, RoleMixin): + name = TextField(unique=True) + description = TextField(null=True) + + class User(db.Model, UserMixin): username = TextField() password = TextField(null=True) last_login_at = DateTimeField(null=True) @@ -322,7 +564,7 @@ class UserRoles(db.Model): name = property(lambda self: self.role.name) description = property(lambda self: self.role.description) - with app.app_context(): + with username_app.app_context(): for Model in (Role, User, UserRoles): Model.create_table() @@ -337,7 +579,17 @@ def tear_down(): @pytest.fixture() -def pony_datastore(request, app, tmpdir): +def peewee_datastore(app, + peewee_email_datastore, + peewee_username_datastore): + if app.config["SECURITY_USER_IDENTITY_ATTRIBUTES"] == ["email"]: + return peewee_email_datastore + else: + return peewee_username_datastore + + +@pytest.fixture() +def pony_email_datastore(request, email_app, tmpdir): from pony.orm import Database, Optional, Required, Set from pony.orm.core import SetInstance @@ -351,7 +603,46 @@ class Role(db.Entity): class User(db.Entity): email = Required(str) - username = Optional(str) + username = Optional(str) # TODO: remove + password = Optional(str, nullable=True) + last_login_at = Optional(datetime) + current_login_at = Optional(datetime) + last_login_ip = Optional(str) + current_login_ip = Optional(str) + login_count = Optional(int) + active = Required(bool, default=True) + confirmed_at = Optional(datetime) + roles = Set(lambda: Role) + + def has_role(self, name): + return name in {r.name for r in self.roles.copy()} + + email_app.config['DATABASE'] = { + 'name': ':memory:', + 'engine': 'pony.SqliteDatabase' + } + + db.bind('sqlite', ':memory:', create_db=True) + db.generate_mapping(create_tables=True) + + return PonyUserDatastore(db, User, Role) + + +@pytest.fixture() +def pony_username_datastore(request, username_app, tmpdir): + from pony.orm import Database, Optional, Required, Set + from pony.orm.core import SetInstance + + SetInstance.append = SetInstance.add + db = Database() + + class Role(db.Entity): + name = Required(str, unique=True) + description = Optional(str, nullable=True) + users = Set(lambda: User) + + class User(db.Entity): + username = Required(str) password = Optional(str, nullable=True) last_login_at = Optional(datetime) current_login_at = Optional(datetime) @@ -365,7 +656,7 @@ class User(db.Entity): def has_role(self, name): return name in {r.name for r in self.roles.copy()} - app.config['DATABASE'] = { + username_app.config['DATABASE'] = { 'name': ':memory:', 'engine': 'pony.SqliteDatabase' } @@ -376,6 +667,16 @@ def has_role(self, name): return PonyUserDatastore(db, User, Role) +@pytest.fixture() +def pony_datastore(app, + pony_email_datastore, + pony_username_datastore): + if app.config["SECURITY_USER_IDENTITY_ATTRIBUTES"] == ["email"]: + return pony_email_datastore + else: + return pony_username_datastore + + @pytest.fixture() def sqlalchemy_app(app, sqlalchemy_datastore): def create(): @@ -440,37 +741,92 @@ def fn(key, **kwargs): @pytest.fixture(params=['sqlalchemy', 'sqlalchemy-session', 'mongoengine', 'peewee', 'pony']) -def datastore( - request, - sqlalchemy_datastore, - sqlalchemy_session_datastore, - mongoengine_datastore, - peewee_datastore, - pony_datastore): - if request.param == 'sqlalchemy': - rv = sqlalchemy_datastore - elif request.param == 'sqlalchemy-session': - rv = sqlalchemy_session_datastore - elif request.param == 'mongoengine': - rv = mongoengine_datastore - elif request.param == 'peewee': - rv = peewee_datastore - elif request.param == 'pony': - rv = pony_datastore +def datastore_backend_name(request): + return request.param + + +@pytest.fixture() +def email_datastore( + datastore_backend_name, + sqlalchemy_email_datastore, + sqlalchemy_session_email_datastore, + mongoengine_email_datastore, + peewee_email_datastore, + pony_email_datastore): + if datastore_backend_name == 'sqlalchemy': + rv = sqlalchemy_email_datastore + elif datastore_backend_name == 'sqlalchemy-session': + rv = sqlalchemy_session_email_datastore + elif datastore_backend_name == 'mongoengine': + rv = mongoengine_email_datastore + elif datastore_backend_name == 'peewee': + rv = peewee_email_datastore + elif datastore_backend_name == 'pony': + rv = pony_email_datastore + return rv + + +@pytest.fixture() +def username_datastore( + datastore_backend_name, + sqlalchemy_username_datastore, + sqlalchemy_session_username_datastore, + mongoengine_username_datastore, + peewee_username_datastore, + pony_username_datastore): + if datastore_backend_name == 'sqlalchemy': + rv = sqlalchemy_username_datastore + elif datastore_backend_name == 'sqlalchemy-session': + rv = sqlalchemy_session_username_datastore + elif datastore_backend_name == 'mongoengine': + rv = mongoengine_username_datastore + elif datastore_backend_name == 'peewee': + rv = peewee_username_datastore + elif datastore_backend_name == 'pony': + rv = pony_username_datastore return rv @pytest.fixture() -def script_info(app, datastore): +def datastore(app, + email_datastore, + username_datastore): + if app.config["SECURITY_USER_IDENTITY_ATTRIBUTES"] == ["email"]: + return email_datastore + else: + return username_datastore + + +@pytest.fixture() +def email_script_info(email_app, email_datastore): try: from flask.cli import ScriptInfo except ImportError: from flask_cli import ScriptInfo def create_app(info): - app.config.update(**{ - 'SECURITY_USER_IDENTITY_ATTRIBUTES': ('email', 'username') - }) - app.security = Security(app, datastore=datastore) - return app + email_app.security = Security(email_app, datastore=email_datastore) + return email_app + return ScriptInfo(create_app=create_app) + + +@pytest.fixture() +def username_script_info(username_app, username_datastore): + try: + from flask.cli import ScriptInfo + except ImportError: + from flask_cli import ScriptInfo + + def create_app(info): + username_app.security = Security(username_app, + datastore=username_datastore) + return username_app return ScriptInfo(create_app=create_app) + + +@pytest.fixture() +def script_info(app, email_script_info, username_script_info): + if app.config["SECURITY_USER_IDENTITY_ATTRIBUTES"] == ["email"]: + return email_script_info + else: + return username_script_info diff --git a/tests/templates/_current_user_email.html b/tests/templates/_current_user_email.html new file mode 100644 index 00000000..fc4cc55d --- /dev/null +++ b/tests/templates/_current_user_email.html @@ -0,0 +1,3 @@ +{%- if current_user.is_authenticated -%} +

Hello {{ current_user.email }}

+{%- endif %} diff --git a/tests/templates/_current_user_username.html b/tests/templates/_current_user_username.html new file mode 100644 index 00000000..5fdc9c75 --- /dev/null +++ b/tests/templates/_current_user_username.html @@ -0,0 +1,3 @@ +{%- if current_user.is_authenticated -%} +

Hello {{ current_user.username }}

+{%- endif %} diff --git a/tests/templates/_nav.html b/tests/templates/_nav.html index 53dc2668..41888460 100644 --- a/tests/templates/_nav.html +++ b/tests/templates/_nav.html @@ -1,6 +1,8 @@ -{%- if current_user.is_authenticated -%} -

Hello {{ current_user.email }}

-{%- endif %} +{%- if config.SECURITY_USER_IDENTITY_ATTRIBUTES == ["email"] -%} + {% include "_current_user_email.html" %} +{%- else -%} + {% include "_current_user_username.html" %} +{%- endif -%} \ No newline at end of file + diff --git a/tests/templates/register.html b/tests/templates/register.html index 7f4e27ac..23cf6cb3 100644 --- a/tests/templates/register.html +++ b/tests/templates/register.html @@ -3,7 +3,7 @@

Register

{{ register_user_form.hidden_tag() }} - {{ register_user_form.email.label }} {{ register_user_form.email }}
+ {{ register_user_form.identity.label }} {{ register_user_form.identity }}
{{ register_user_form.password.label }} {{ register_user_form.password }}
{{ register_user_form.password_confirm.label }} {{ register_user_form.password_confirm }}
{{ register_user_form.submit }} diff --git a/tests/test_cli.py b/tests/test_cli.py index 19a1962d..2506c8c8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -121,6 +121,72 @@ def test_cli_addremove_role(script_info): assert result.exit_code == 0 +def test_cli_addremove_role_username(username_script_info): + """Test add/remove role, for username-only model""" + runner = CliRunner() + script_info = username_script_info + + # Create a user and a role + result = runner.invoke( + users_create, + ['alice', '--password', '123456'], + obj=script_info + ) + assert result.exit_code == 0 + result = runner.invoke(roles_create, ['superuser'], obj=script_info) + assert result.exit_code == 0 + + # User not found + result = runner.invoke( + roles_add, ['inval', 'superuser'], + obj=script_info) + assert result.exit_code != 0 + + # Add: + result = runner.invoke( + roles_add, ['alice', 'invalid'], + obj=script_info) + assert result.exit_code != 0 + + result = runner.invoke( + roles_remove, ['inval', 'superuser'], + obj=script_info) + assert result.exit_code != 0 + + # Remove: + result = runner.invoke( + roles_remove, ['alice', 'invalid'], + obj=script_info) + assert result.exit_code != 0 + + result = runner.invoke( + roles_remove, ['bob', 'superuser'], + obj=script_info) + assert result.exit_code != 0 + + result = runner.invoke( + roles_remove, ['alice', 'superuser'], + obj=script_info) + assert result.exit_code != 0 + + # Add: + result = runner.invoke(roles_add, + ['alice', 'superuser'], + obj=script_info) + assert result.exit_code == 0 + result = runner.invoke( + roles_add, + ['alice', 'superuser'], + obj=script_info) + assert result.exit_code != 0 + + # Remove: + result = runner.invoke( + roles_remove, ['alice', 'superuser'], + obj=script_info) + assert result.exit_code == 0 + + def test_cli_activate_deactivate(script_info): """Test create user CLI.""" runner = CliRunner() diff --git a/tests/test_datastore.py b/tests/test_datastore.py index 218923eb..26b5f386 100644 --- a/tests/test_datastore.py +++ b/tests/test_datastore.py @@ -81,10 +81,10 @@ def test_activate_returns_false_if_already_true(): assert not datastore.activate_user(user) -def test_get_user(app, datastore): - init_app_with_options(app, datastore, **{ - 'SECURITY_USER_IDENTITY_ATTRIBUTES': ('email', 'username') - }) +def test_get_user_by_email(email_app, email_datastore): + app = email_app + datastore = email_datastore + init_app_with_options(app, datastore) with app.app_context(): user_id = datastore.find_user(email='matt@lp.com').id @@ -95,14 +95,26 @@ def test_get_user(app, datastore): user = datastore.get_user('matt@lp.com') assert user is not None - user = datastore.get_user('matt') - assert user is not None - # Regression check user = datastore.get_user('%lp.com') assert user is None +def test_get_user_by_username(username_app, username_datastore): + app = username_app + datastore = username_datastore + init_app_with_options(app, datastore) + + with app.app_context(): + user_id = datastore.find_user(username='matt').id + + user = datastore.get_user(user_id) + assert user is not None + + user = datastore.get_user('matt') + assert user is not None + + def test_find_role(app, datastore): init_app_with_options(app, datastore) diff --git a/tests/test_misc.py b/tests/test_misc.py index 81dc6fce..14145c79 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -21,15 +21,15 @@ @pytest.mark.recoverable() -def test_async_email_task(app, client): - app.mail_sent = False +def test_async_email_task(email_app, client): + email_app.mail_sent = False - @app.security.send_mail_task + @email_app.security.send_mail_task def send_email(msg): - app.mail_sent = True + email_app.mail_sent = True client.post('/reset', data=dict(email='matt@lp.com')) - assert app.mail_sent is True + assert email_app.mail_sent is True def test_register_blueprint_flag(app, sqlalchemy_datastore): @@ -138,13 +138,12 @@ class MyPasswordlessLoginForm(PasswordlessLoginForm): assert b'My Passwordless Email Address Field' in response.data -def test_addition_identity_attributes(app, sqlalchemy_datastore): - init_app_with_options(app, sqlalchemy_datastore, **{ - 'SECURITY_USER_IDENTITY_ATTRIBUTES': ('email', 'username') - }) - client = app.test_client() - response = authenticate(client, email='matt', follow_redirects=True) - assert b'Hello matt@lp.com' in response.data +def test_username_identity_attribute(username_app, + sqlalchemy_username_datastore): + init_app_with_options(username_app, sqlalchemy_username_datastore) + client = username_app.test_client() + response = authenticate(client, username='matt', follow_redirects=True) + assert b'Hello matt' in response.data def test_flash_messages_off(app, sqlalchemy_datastore, get_message): @@ -234,7 +233,7 @@ def test_password_unicode_password_salt(client): response = authenticate(client) assert response.status_code == 302 response = authenticate(client, follow_redirects=True) - assert b'Hello matt@lp.com' in response.data + assert b'Hello matt' in response.data def test_set_unauthorized_handler(app, client): diff --git a/tests/test_recoverable.py b/tests/test_recoverable.py index a741d875..27c32440 100644 --- a/tests/test_recoverable.py +++ b/tests/test_recoverable.py @@ -21,7 +21,8 @@ pytestmark = pytest.mark.recoverable() -def test_recoverable_flag(app, client, get_message): +def test_recoverable_flag(email_app, client, get_message): + app = email_app recorded_resets = [] recorded_instructions_sent = [] @@ -75,8 +76,8 @@ def on_instructions_sent(app, user, token): # Test logging in with the new password response = authenticate( client, - 'joe@lp.com', - 'newpassword', + email='joe@lp.com', + password='newpassword', follow_redirects=True) assert b'Hello joe@lp.com' in response.data diff --git a/tests/utils.py b/tests/utils.py index 8d7710a3..af29a404 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -18,10 +18,18 @@ def authenticate( client, email="matt@lp.com", + username="matt", password="password", endpoint=None, **kwargs): - data = dict(email=email, password=password, remember='y') + identity_attrs = ( + client.application.config["SECURITY_USER_IDENTITY_ATTRIBUTES"] + ) + data = dict(password=password, remember='y') + if identity_attrs == ["email"]: + data["email"] = email + else: + data["username"] = username return client.post(endpoint or '/login', data=data, **kwargs) @@ -47,7 +55,7 @@ def create_roles(ds): ds.commit() -def create_users(ds, count=None): +def create_users_email(ds, count=None): users = [('matt@lp.com', 'matt', 'password', ['admin'], True), ('joe@lp.com', 'joe', 'password', ['editor'], True), ('dave@lp.com', 'dave', 'password', ['admin', 'editor'], True), @@ -73,11 +81,40 @@ def create_users(ds, count=None): ds.commit() +def create_users_username(ds, count=None): + users = [('matt', 'password', ['admin'], True), + ('joe', 'password', ['editor'], True), + ('dave', 'password', ['admin', 'editor'], True), + ('jill', 'password', ['author'], True), + ('tiya', 'password', [], False), + ('jess', None, [], True)] + count = count or len(users) + + for u in users[:count]: + pw = u[1] + if pw is not None: + pw = encrypt_password(pw) + roles = [ds.find_or_create_role(rn) for rn in u[2]] + ds.commit() + user = ds.create_user( + username=u[0], + password=pw, + active=u[3]) + ds.commit() + for role in roles: + ds.add_role_to_user(user, role) + ds.commit() + + def populate_data(app, user_count=None): ds = app.security.datastore + identity_attrs = app.config["SECURITY_USER_IDENTITY_ATTRIBUTES"] with app.app_context(): create_roles(ds) - create_users(ds, user_count) + if identity_attrs == ["email"]: + create_users_email(ds, user_count) + else: + create_users_username(ds, user_count) class Response(BaseResponse): # pragma: no cover