]> littlesong.place Git - littlesongplace.git/commitdiff
Allow password hashes as bytes or strings in DB dev master
authorChris Fulljames <christianfulljames@gmail.com>
Sat, 2 May 2026 15:59:19 +0000 (11:59 -0400)
committerChris Fulljames <christianfulljames@gmail.com>
Sat, 2 May 2026 15:59:19 +0000 (11:59 -0400)
src/littlesongplace/auth.py

index 21b15afd3c1b53d986b963321dede4aec6aacb3a..02dcbc3102659d0d7464876d951bb2f4aff76a59 100644 (file)
@@ -13,7 +13,7 @@ bp = Blueprint("auth", __name__)
 def signup_get():
     return render_template("signup.html")
 
-def _check_password(password, password_confirm):
+def _validate_password(password, password_confirm):
     error = False
     if password != password_confirm:
         flash_and_log("Passwords do not match", "error")
@@ -23,6 +23,17 @@ def _check_password(password, password_confirm):
         error = True
     return error
 
+def _check_password(user_data, password):
+    if not user_data:
+        return False
+
+    # Password has must be bytes
+    pwhash = user_data["password"]
+    if isinstance(pwhash, str):
+        pwhash = pwhash.encode()
+
+    return bcrypt.checkpw(password.encode(), pwhash)
+
 def _hash_password(password):
     return bcrypt.hashpw(password.encode(), bcrypt.gensalt())
 
@@ -43,7 +54,7 @@ def signup_post():
         flash_and_log("Username cannot be more than 30 characters", "error")
         error = True
 
-    error = error or _check_password(password, password_confirm)
+    error = error or _validate_password(password, password_confirm)
 
     if db.query("select * from users where username = ?", [username], one=True):
         flash_and_log(f"Username '{username}' is already taken", "error")
@@ -86,7 +97,7 @@ def login_post():
 
     user_data = db.query("select * from users where username = ?", [username], one=True)
 
-    if user_data and bcrypt.checkpw(password.encode(), user_data["password"]):
+    if _check_password(user_data, password):
         # Successful login
         session["username"] = username
         session["userid"] = user_data["userid"]
@@ -122,12 +133,12 @@ def password_reset_post():
         error = True
         
     # Check old password
-    elif not bcrypt.checkpw(old_password.encode(), user_data["password"]):
+    if not _check_password(user_data, old_password):
         flash("Invalid username/password", "error")
         error = True
 
     # Check new password
-    error = error or _check_password(password, password_confirm)
+    error = error or _validate_password(password, password_confirm)
 
     # Reload page on error
     if error: