]> littlesong.place Git - littlesongplace.git/commitdiff
Start refactoring __init__
authorChris Fulljames <christianfulljames@gmail.com>
Sat, 5 Apr 2025 14:24:18 +0000 (10:24 -0400)
committerChris Fulljames <christianfulljames@gmail.com>
Sat, 5 Apr 2025 14:24:18 +0000 (10:24 -0400)
src/littlesongplace/__init__.py
src/littlesongplace/auth.py [new file with mode: 0644]
src/littlesongplace/comments.py [new file with mode: 0644]
src/littlesongplace/datadir.py [new file with mode: 0644]
src/littlesongplace/db.py [new file with mode: 0644]
src/littlesongplace/logutils.py [new file with mode: 0644]
test/conftest.py

index 8e62746ed60ec52ff624c0bec6bb8378ba2242d3..5c3ff73924a4e5300083ef22692fadc811caac73 100644 (file)
@@ -5,7 +5,6 @@ import logging
 import os
 import random
 import shutil
-import sqlite3
 import subprocess
 import sys
 import tempfile
@@ -16,7 +15,6 @@ from logging.handlers import RotatingFileHandler
 from pathlib import Path, PosixPath
 from typing import Optional
 
-import bcrypt
 import bleach
 import click
 from bleach.css_sanitizer import CSSSanitizer
@@ -28,12 +26,8 @@ from werkzeug.middleware.proxy_fix import ProxyFix
 from yt_dlp import YoutubeDL
 from yt_dlp.utils import DownloadError
 
-DB_VERSION = 4
-SCRIPT_DIR = Path(__file__).parent
-DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else Path(".data").absolute()
-
-# Make sure DATA_DIR exists
-os.makedirs(DATA_DIR, exist_ok=True)
+from . import auth, comments, datadir, db
+from .logutils import flash_and_log
 
 BGCOLOR = "#e8e6b5"
 FGCOLOR = "#695c73"
@@ -44,7 +38,7 @@ DEFAULT_COLORS = dict(bgcolor=BGCOLOR, fgcolor=FGCOLOR, accolor=ACCOLOR)
 # Logging
 ################################################################################
 
-handler = RotatingFileHandler(DATA_DIR / "app.log", maxBytes=1_000_000, backupCount=10)
+handler = RotatingFileHandler(datadir.get_app_log_path(), maxBytes=1_000_000, backupCount=10)
 handler.setLevel(logging.INFO)
 handler.setFormatter(logging.Formatter('[%(asctime)s] %(levelname)s in %(module)s: %(message)s'))
 
@@ -58,6 +52,8 @@ root_logger.addHandler(handler)
 app = Flask(__name__)
 app.secret_key = os.environ["SECRET_KEY"] if "SECRET_KEY" in os.environ else "dev"
 app.config["MAX_CONTENT_LENGTH"] = 1 * 1024 * 1024 * 1024
+app.register_blueprint(auth.bp)
+db.init_app(app)
 
 if "DATA_DIR" in os.environ:
     # Running on server behind proxy
@@ -68,7 +64,7 @@ if "DATA_DIR" in os.environ:
 
 @app.route("/")
 def index():
-    users = query_db("select * from users order by username asc")
+    users = db.query("select * from users order by username asc")
     users = [dict(row) for row in users]
     for user in users:
         user["has_pfp"] = user_has_pfp(user["userid"])
@@ -87,97 +83,11 @@ def index():
     songs = Song.get_latest(50)
     return render_template("index.html", users=users, songs=songs, page_title=title)
 
-@app.get("/signup")
-def signup_get():
-    return render_template("signup.html")
-
-@app.post("/signup")
-def signup_post():
-    username = request.form["username"]
-    password = request.form["password"]
-    password_confirm = request.form["password_confirm"]
-
-    error = False
-    if not username.isidentifier():
-        flash_and_log("Username cannot contain special characters", "error")
-        error = True
-    elif len(username) < 3:
-        flash_and_log("Username must be at least 3 characters", "error")
-        error = True
-    elif len(username) > 30:
-        flash_and_log("Username cannot be more than 30 characters", "error")
-        error = True
-
-    elif password != password_confirm:
-        flash_and_log("Passwords do not match", "error")
-        error = True
-    elif len(password) < 8:
-        flash_and_log("Password must be at least 8 characters", "error")
-        error = True
-
-    if query_db("select * from users where username = ?", [username], one=True):
-        flash_and_log(f"Username '{username}' is already taken", "error")
-        error = True
-
-    if error:
-        app.logger.info("Failed signup attempt")
-        return redirect(request.referrer)
-
-    password = bcrypt.hashpw(password.encode(), bcrypt.gensalt())
-    timestamp = datetime.now(timezone.utc).isoformat()
-
-    user_data = query_db("insert into users (username, password, created) values (?, ?, ?) returning userid", [username, password, timestamp], one=True)
-
-    # Create profile comment thread
-    threadid = create_comment_thread(ThreadType.PROFILE, user_data["userid"])
-    query_db("update users set threadid = ? where userid = ?", [threadid, user_data["userid"]])
-    get_db().commit()
-
-    flash("User created.  Please sign in to continue.", "success")
-    app.logger.info(f"Created user {username}")
-
-    return redirect("/login")
-
-@app.get("/login")
-def login_get():
-    return render_template("login.html")
-
-@app.post("/login")
-def login_post():
-    username = request.form["username"]
-    password = request.form["password"]
-
-    user_data = query_db("select * from users where username = ?", [username], one=True)
-
-    if user_data and bcrypt.checkpw(password.encode(), user_data["password"]):
-        # Successful login
-        session["username"] = username
-        session["userid"] = user_data["userid"]
-        session.permanent = True
-        app.logger.info(f"{username} logged in")
-
-        return redirect(f"/users/{username}")
-
-    flash("Invalid username/password", "error")
-    app.logger.info(f"Failed login for {username}")
-
-    return render_template("login.html")
-
-
-@app.get("/logout")
-def logout():
-    if "username" in session:
-        session.pop("username")
-    if "userid" in session:
-        session.pop("userid")
-
-    return redirect("/")
-
 @app.get("/users/<profile_username>")
 def users_profile(profile_username):
 
     # Look up user data for current profile
-    profile_data = query_db("select * from users where username = ?", [profile_username], one=True)
+    profile_data = db.query("select * from users where username = ?", [profile_username], one=True)
     if profile_data is None:
         abort(404)
     profile_userid = profile_data["userid"]
@@ -186,15 +96,15 @@ def users_profile(profile_username):
     userid = session.get("userid", None)
     show_private = userid == profile_userid
     if show_private:
-        plist_data = query_db("select * from playlists where userid = ? order by updated desc", [profile_userid])
+        plist_data = db.query("select * from playlists where userid = ? order by updated desc", [profile_userid])
     else:
-        plist_data = query_db("select * from playlists where userid = ? and private = 0 order by updated desc", [profile_userid])
+        plist_data = db.query("select * from playlists where userid = ? and private = 0 order by updated desc", [profile_userid])
 
     # Get songs for current profile
     songs = Song.get_all_for_userid(profile_userid)
 
     # Get comments for current profile
-    comments = get_comments(profile_data["threadid"])
+    profile_comments = get_comments(profile_data["threadid"])
 
     # Sanitize bio
     profile_bio = ""
@@ -209,7 +119,7 @@ def users_profile(profile_username):
             **get_user_colors(profile_data),
             playlists=plist_data,
             songs=songs,
-            comments=comments,
+            comments=profile_comments,
             threadid=profile_data["threadid"],
             user_has_pfp=user_has_pfp(profile_userid))
 
@@ -218,13 +128,13 @@ def edit_profile():
     if not "userid" in session:
         abort(401)
 
-    query_db(
+    db.query(
             "update users set bio = ?, bgcolor = ?, fgcolor = ?, accolor = ? where userid = ?",
             [request.form["bio"], request.form["bgcolor"], request.form["fgcolor"], request.form["accolor"], session["userid"]])
-    get_db().commit()
+    db.commit()
 
     if request.files["pfp"]:
-        pfp_path = get_user_images_path(session["userid"]) / "pfp.jpg"
+        pfp_path = datadir.get_user_images_path(session["userid"]) / "pfp.jpg"
 
         try:
             with Image.open(request.files["pfp"]) as im:
@@ -263,7 +173,7 @@ def edit_profile():
 
 @app.get("/pfp/<int:userid>")
 def pfp(userid):
-    return send_from_directory(get_user_images_path(userid), "pfp.jpg")
+    return send_from_directory(datadir.get_user_images_path(userid), "pfp.jpg")
 
 @app.get("/edit-song")
 def edit_song():
@@ -362,18 +272,6 @@ def validate_song_form():
 
     return error
 
-def get_user_songs_path(userid):
-    userpath = DATA_DIR / "songs" / str(userid)
-    if not userpath.exists():
-        os.makedirs(userpath)
-    return userpath
-
-def get_user_images_path(userid):
-    userpath = DATA_DIR / "images" / str(userid)
-    if not userpath.exists():
-        os.makedirs(userpath)
-    return userpath
-
 def update_song():
     songid = request.args["songid"]
     try:
@@ -389,7 +287,7 @@ def update_song():
     collaborators = [c.strip() for c in request.form["collabs"].split(",") if c]
 
     # Make sure song exists and the logged-in user owns it
-    song_data = query_db("select * from songs where songid = ?", [songid], one=True)
+    song_data = db.query("select * from songs where songid = ?", [songid], one=True)
     if song_data is None:
         abort(400)
     elif session["userid"] != song_data["userid"]:
@@ -402,28 +300,28 @@ def update_song():
 
             if passed:
                 # Move file to permanent location
-                filepath = get_user_songs_path(session["userid"]) / (str(song_data["songid"]) + ".mp3")
+                filepath = datadir.get_user_songs_path(session["userid"]) / (str(song_data["songid"]) + ".mp3")
                 shutil.move(tmp_file.name, filepath)
             else:
                 error = True
 
     if not error:
         # Update songs table
-        query_db(
+        db.query(
                 "update songs set title = ?, description = ? where songid = ?",
                 [title, description, songid])
 
         # Update song_tags table
-        query_db("delete from song_tags where songid = ?", [songid])
+        db.query("delete from song_tags where songid = ?", [songid])
         for tag in tags:
-            query_db("insert into song_tags (tag, songid) values (?, ?)", [tag, songid])
+            db.query("insert into song_tags (tag, songid) values (?, ?)", [tag, songid])
 
         # Update song_collaborators table
-        query_db("delete from song_collaborators where songid = ?", [songid])
+        db.query("delete from song_collaborators where songid = ?", [songid])
         for collab in collaborators:
-            query_db("insert into song_collaborators (name, songid) values (?, ?)", [collab, songid])
+            db.query("insert into song_collaborators (name, songid) values (?, ?)", [collab, songid])
 
-        get_db().commit()
+        db.commit()
         flash_and_log(f"Successfully updated '{title}'", "success")
 
     return error
@@ -443,27 +341,27 @@ def create_song():
             return True
         else:
             # Create comment thread
-            threadid = create_comment_thread(ThreadType.SONG, session["userid"])
+            threadid = comments.create_thread(comments.ThreadType.SONG, session["userid"])
             # Create song
             timestamp = datetime.now(timezone.utc).isoformat()
-            song_data = query_db(
+            song_data = db.query(
                     "insert into songs (userid, title, description, created, threadid) values (?, ?, ?, ?, ?) returning (songid)",
                     [session["userid"], title, description, timestamp, threadid], one=True)
             songid = song_data["songid"]
-            filepath = get_user_songs_path(session["userid"]) / (str(song_data["songid"]) + ".mp3")
+            filepath = datadir.get_user_songs_path(session["userid"]) / (str(song_data["songid"]) + ".mp3")
 
             # Move file to permanent location
             shutil.move(tmp_file.name, filepath)
 
             # Assign tags
             for tag in tags:
-                query_db("insert into song_tags (tag, songid) values (?, ?)", [tag, songid])
+                db.query("insert into song_tags (tag, songid) values (?, ?)", [tag, songid])
 
             # Assign collaborators
             for collab in collaborators:
-                query_db("insert into song_collaborators (songid, name) values (?, ?)", [songid, collab])
+                db.query("insert into song_collaborators (songid, name) values (?, ?)", [songid, collab])
 
-            get_db().commit()
+            db.commit()
 
             flash_and_log(f"Successfully uploaded '{title}'", "success")
             return False
@@ -521,7 +419,7 @@ def yt_import(tmp_file, yt_url):
 @app.get("/delete-song/<int:songid>")
 def delete_song(songid):
 
-    song_data = query_db("select * from songs where songid = ?", [songid], one=True)
+    song_data = db.query("select * from songs where songid = ?", [songid], one=True)
 
     if not song_data:
         app.logger.warning(f"Failed song delete - {session['username']} - song doesn't exist")
@@ -533,15 +431,15 @@ def delete_song(songid):
         abort(401)
 
     # Delete tags, collaborators
-    query_db("delete from song_tags where songid = ?", [songid])
-    query_db("delete from song_collaborators where songid = ?", [songid])
+    db.query("delete from song_tags where songid = ?", [songid])
+    db.query("delete from song_collaborators where songid = ?", [songid])
 
     # Delete song database entry
-    query_db("delete from songs where songid = ?", [songid])
-    get_db().commit()
+    db.query("delete from songs where songid = ?", [songid])
+    db.commit()
 
     # Delete song file from disk
-    songpath = DATA_DIR / "songs" / str(session["userid"]) / (str(songid) + ".mp3")
+    songpath = datadir.get_user_songs_path(session["userid"]) / (str(songid) + ".mp3")
     if songpath.exists():
         os.remove(songpath)
 
@@ -566,7 +464,7 @@ def song(userid, songid):
         except ValueError:
             abort(404)
     else:
-        return send_from_directory(DATA_DIR / "songs" / str(userid), str(songid) + ".mp3")
+        return send_from_directory(datadir.get_user_songs_path(userid), str(songid) + ".mp3")
 
 @app.get("/songs")
 def songs():
@@ -596,7 +494,7 @@ def comment():
     if not "threadid" in request.args:
         abort(400) # Must have threadid
 
-    thread = query_db("select * from comment_threads where threadid = ?", [request.args["threadid"]], one=True)
+    thread = db.query("select * from comment_threads where threadid = ?", [request.args["threadid"]], one=True)
     if not thread:
         abort(404) # Invalid threadid
 
@@ -604,7 +502,7 @@ def comment():
     replyto = None
     if "replytoid" in request.args:
         replytoid = request.args["replytoid"]
-        replyto = query_db("select * from comments inner join users on comments.userid == users.userid where commentid = ?", [replytoid], one=True)
+        replyto = db.query("select * from comments inner join users on comments.userid == users.userid where commentid = ?", [replytoid], one=True)
         if not replyto:
             abort(404) # Invalid comment
 
@@ -612,7 +510,7 @@ def comment():
     comment = None
     if "commentid" in request.args:
         commentid = request.args["commentid"]
-        comment = query_db("select * from comments inner join users on comments.userid == users.userid where commentid = ?", [commentid], one=True)
+        comment = db.query("select * from comments inner join users on comments.userid == users.userid where commentid = ?", [commentid], one=True)
         if not comment:
             abort(404) # Invalid comment
         if comment["userid"] != session["userid"]:
@@ -625,12 +523,12 @@ def comment():
         song = None
         profile = None
         playlist = None
-        if threadtype == ThreadType.SONG:
+        if threadtype == comments.ThreadType.SONG:
             song = Song.by_threadid(request.args["threadid"])
-        elif threadtype == ThreadType.PROFILE:
-            profile = query_db("select * from users where threadid = ?", [request.args["threadid"]], one=True)
-        elif threadtype == ThreadType.PLAYLIST:
-            profile = query_db("select * from playlists inner join users on playlists.userid = users.userid where playlists.threadid = ?", [request.args["threadid"]], one=True)
+        elif threadtype == comments.ThreadType.PROFILE:
+            profile = db.query("select * from users where threadid = ?", [request.args["threadid"]], one=True)
+        elif threadtype == comments.ThreadType.PLAYLIST:
+            profile = db.query("select * from playlists inner join users on playlists.userid = users.userid where playlists.threadid = ?", [request.args["threadid"]], one=True)
         return render_template(
             "comment.html",
             song=song,
@@ -645,7 +543,7 @@ def comment():
         content = request.form["content"]
         if comment:
             # Update existing comment
-            query_db("update comments set content = ? where commentid = ?", args=[content, comment["commentid"]])
+            db.query("update comments set content = ? where commentid = ?", args=[content, comment["commentid"]])
         else:
             # Add new comment
             timestamp = datetime.now(timezone.utc).isoformat()
@@ -653,7 +551,7 @@ def comment():
             replytoid = request.args.get("replytoid", None)
 
             threadid = request.args["threadid"]
-            comment = query_db(
+            comment = db.query(
                     "insert into comments (threadid, userid, replytoid, created, content) values (?, ?, ?, ?, ?) returning (commentid)",
                     args=[threadid, userid, replytoid, timestamp, content], one=True)
             commentid = comment["commentid"]
@@ -665,7 +563,7 @@ def comment():
                 notification_targets.add(replyto["userid"])
 
                 # Notify previous repliers in thread
-                previous_replies = query_db("select * from comments where replytoid = ?", [replytoid])
+                previous_replies = db.query("select * from comments where replytoid = ?", [replytoid])
                 for reply in previous_replies:
                     notification_targets.add(reply["userid"])
 
@@ -675,9 +573,9 @@ def comment():
 
             # Create notifications
             for target in notification_targets:
-                query_db("insert into notifications (objectid, objecttype, targetuserid, created) values (?, ?, ?, ?)", [commentid, ObjectType.COMMENT, target, timestamp])
+                db.query("insert into notifications (objectid, objecttype, targetuserid, created) values (?, ?, ?, ?)", [commentid, ObjectType.COMMENT, target, timestamp])
 
-        get_db().commit()
+        db.commit()
 
         return redirect_to_previous_page()
 
@@ -693,7 +591,7 @@ def comment_delete(commentid):
     if "userid" not in session:
         return redirect("/login")
 
-    comment = query_db("select c.userid as comment_user, t.userid as thread_user from comments as c inner join comment_threads as t on c.threadid == t.threadid where commentid = ?", [commentid], one=True)
+    comment = db.query("select c.userid as comment_user, t.userid as thread_user from comments as c inner join comment_threads as t on c.threadid == t.threadid where commentid = ?", [commentid], one=True)
     if not comment:
         abort(404) # Invalid comment
 
@@ -702,8 +600,8 @@ def comment_delete(commentid):
             or (comment["thread_user"] == session["userid"])):
         abort(403)
 
-    query_db("delete from comments where (commentid = ?) or (replytoid = ?)", [commentid, commentid])
-    get_db().commit()
+    db.query("delete from comments where (commentid = ?) or (replytoid = ?)", [commentid, commentid])
+    db.commit()
 
     return redirect(request.referrer)
 
@@ -713,7 +611,7 @@ def activity():
         return redirect("/login")
 
     # Get comment notifications
-    comments = query_db(
+    notifications = db.query(
         """\
         select c.content, c.commentid, c.replytoid, cu.username as comment_username, rc.content as replyto_content, c.threadid, t.threadtype
         from notifications as n
@@ -726,21 +624,21 @@ def activity():
         """,
         [session["userid"], ObjectType.COMMENT])
 
-    comments = [dict(c) for c in comments]
-    for comment in comments:
+    notifications = [dict(c) for c in notifications]
+    for comment in notifications:
         threadtype = comment["threadtype"]
-        if threadtype == ThreadType.SONG:
+        if threadtype == comments.ThreadType.SONG:
             song = Song.by_threadid(comment["threadid"])
             comment["songid"] = song.songid
             comment["title"] = song.title
             comment["content_userid"] = song.userid
             comment["content_username"] = song.username
-        elif threadtype == ThreadType.PROFILE:
-            profile = query_db("select * from users where threadid = ?", [comment["threadid"]], one=True)
+        elif threadtype == comments.ThreadType.PROFILE:
+            profile = db.query("select * from users where threadid = ?", [comment["threadid"]], one=True)
             comment["content_userid"] = profile["userid"]
             comment["content_username"] = profile["username"]
-        elif threadtype == ThreadType.PLAYLIST:
-            playlist = query_db(
+        elif threadtype == comments.ThreadType.PLAYLIST:
+            playlist = db.query(
                 """\
                 select * from playlists
                 inner join users on playlists.userid == users.userid
@@ -755,17 +653,17 @@ def activity():
             comment["content_username"] = playlist["username"]
 
     timestamp = datetime.now(timezone.utc).isoformat()
-    query_db("update users set activitytime = ? where userid = ?", [timestamp, session["userid"]])
-    get_db().commit()
+    db.query("update users set activitytime = ? where userid = ?", [timestamp, session["userid"]])
+    db.commit()
 
-    return render_template("activity.html", comments=comments)
+    return render_template("activity.html", comments=notifications)
 
 @app.get("/new-activity")
 def new_activity():
     has_new_activity = False
     if "userid" in session:
-        user_data = query_db("select activitytime from users where userid = ?", [session["userid"]], one=True)
-        comment_data = query_db(
+        user_data = db.query("select activitytime from users where userid = ?", [session["userid"]], one=True)
+        comment_data = db.query(
             """\
             select created from notifications
             where targetuserid = ?
@@ -801,9 +699,9 @@ def create_playlist():
 
     private = request.form["type"] == "private"
 
-    threadid = create_comment_thread(ThreadType.PLAYLIST, session["userid"])
+    threadid = comments.create_thread(comments.ThreadType.PLAYLIST, session["userid"])
 
-    query_db(
+    db.query(
         "insert into playlists (created, updated, userid, name, private, threadid) values (?, ?, ?, ?, ?, ?)",
         args=[
             timestamp,
@@ -814,7 +712,7 @@ def create_playlist():
             threadid
         ]
     )
-    get_db().commit()
+    db.commit()
     flash_and_log(f"Created playlist {name}", "success")
     return redirect(request.referrer)
 
@@ -824,7 +722,7 @@ def delete_playlist(playlistid):
         abort(401)
 
     # Make sure playlist exists
-    plist_data = query_db("select * from playlists where playlistid = ?", args=[playlistid], one=True)
+    plist_data = db.query("select * from playlists where playlistid = ?", args=[playlistid], one=True)
     if not plist_data:
         abort(404)
 
@@ -833,8 +731,8 @@ def delete_playlist(playlistid):
         abort(403)
 
     # Delete playlist
-    query_db("delete from playlists where playlistid = ?", args=[playlistid])
-    get_db().commit()
+    db.query("delete from playlists where playlistid = ?", args=[playlistid])
+    db.commit()
 
     flash_and_log(f"Deleted playlist {plist_data['name']}", "success")
     return redirect(f"/users/{session['username']}")
@@ -850,7 +748,7 @@ def append_to_playlist():
     except ValueError:
         abort(400)
 
-    plist_data = query_db("select * from playlists where playlistid = ?", args=[playlistid], one=True)
+    plist_data = db.query("select * from playlists where playlistid = ?", args=[playlistid], one=True)
     if not plist_data:
         abort(404)
 
@@ -861,21 +759,21 @@ def append_to_playlist():
     songid = request.form["songid"]
 
     # Make sure song exists
-    song_data = query_db("select * from songs where songid = ?", args=[songid], one=True)
+    song_data = db.query("select * from songs where songid = ?", args=[songid], one=True)
     if not song_data:
         abort(404)
 
     # Set index to count of songs in list
-    existing_songs = query_db("select * from playlist_songs where playlistid = ?", args=[playlistid])
+    existing_songs = db.query("select * from playlist_songs where playlistid = ?", args=[playlistid])
     new_position = len(existing_songs)
 
     # Add to playlist
-    query_db("insert into playlist_songs (playlistid, position, songid) values (?, ?, ?)", args=[playlistid, new_position, songid])
+    db.query("insert into playlist_songs (playlistid, position, songid) values (?, ?, ?)", args=[playlistid, new_position, songid])
 
     # Update modification time
     timestamp = datetime.now(timezone.utc).isoformat()
-    query_db("update playlists set updated = ? where playlistid = ?", args=[timestamp, playlistid])
-    get_db().commit()
+    db.query("update playlists set updated = ? where playlistid = ?", args=[timestamp, playlistid])
+    db.commit()
 
     flash_and_log(f"Added '{song_data['title']}' to {plist_data['name']}", "success")
 
@@ -887,7 +785,7 @@ def edit_playlist_post(playlistid):
         abort(401)
 
     # Make sure playlist exists
-    plist_data = query_db("select * from playlists where playlistid = ?", args=[playlistid], one=True)
+    plist_data = db.query("select * from playlists where playlistid = ?", args=[playlistid], one=True)
     if not plist_data:
         abort(404)
 
@@ -911,23 +809,23 @@ def edit_playlist_post(playlistid):
             abort(400)
 
         for songid in songids:
-            song_data = query_db("select * from songs where songid = ?", args=[songid])
+            song_data = db.query("select * from songs where songid = ?", args=[songid])
             if not song_data:
                 abort(400)
 
     # All songs valid - delete old songs
-    query_db("delete from playlist_songs where playlistid = ?", args=[playlistid])
+    db.query("delete from playlist_songs where playlistid = ?", args=[playlistid])
 
     # Re-add songs with new positions
     for position, songid in enumerate(songids):
         print(position, songid)
-        query_db("insert into playlist_songs (playlistid, position, songid) values (?, ?, ?)", args=[playlistid, position, songid])
+        db.query("insert into playlist_songs (playlistid, position, songid) values (?, ?, ?)", args=[playlistid, position, songid])
 
     # Update private, name
     private = int(request.form["type"] == "private")
-    query_db("update playlists set private = ?, name = ? where playlistid = ?", [private, name, playlistid])
+    db.query("update playlists set private = ?, name = ? where playlistid = ?", [private, name, playlistid])
 
-    get_db().commit()
+    db.commit()
 
     flash_and_log("Playlist updated", "success")
     return redirect(request.referrer)
@@ -936,7 +834,7 @@ def edit_playlist_post(playlistid):
 def playlists(playlistid):
 
     # Make sure playlist exists
-    plist_data = query_db("select * from playlists inner join users on playlists.userid = users.userid where playlistid = ?", args=[playlistid], one=True)
+    plist_data = db.query("select * from playlists inner join users on playlists.userid = users.userid where playlistid = ?", args=[playlistid], one=True)
     if not plist_data:
         abort(404)
 
@@ -949,7 +847,7 @@ def playlists(playlistid):
     songs = Song.get_for_playlist(playlistid)
 
     # Get comments
-    comments = get_comments(plist_data["threadid"])
+    plist_comments = get_comments(plist_data["threadid"])
 
     # Show page
     return render_template(
@@ -962,17 +860,7 @@ def playlists(playlistid):
             threadid=plist_data["threadid"],
             **get_user_colors(plist_data),
             songs=songs,
-            comments=comments)
-
-def flash_and_log(msg, category=None):
-    flash(msg, category)
-    username = session["username"] if "username" in session else "N/A"
-    url = request.referrer
-    logmsg = f"[{category}] User: {username}, URL: {url} - {msg}"
-    if category == "error":
-        app.logger.warning(logmsg)
-    else:
-        app.logger.info(logmsg)
+            comments=plist_comments)
 
 def sanitize_user_text(text):
         allowed_tags = bleach.sanitizer.ALLOWED_TAGS.union({
@@ -996,23 +884,18 @@ def sanitize_user_text(text):
                 attributes=allowed_attributes,
                 css_sanitizer=css_sanitizer)
 
-def create_comment_thread(threadtype, userid):
-    thread = query_db("insert into comment_threads (threadtype, userid) values (?, ?) returning threadid", [threadtype, userid], one=True)
-    get_db().commit()
-    return thread["threadid"]
-
 def get_comments(threadid):
-    comments = query_db("select * from comments inner join users on comments.userid == users.userid where comments.threadid = ?", [threadid])
-    comments = [dict(c) for c in comments]
-    for c in comments:
+    thread_comments = db.query("select * from comments inner join users on comments.userid == users.userid where comments.threadid = ?", [threadid])
+    thread_comments = [dict(c) for c in thread_comments]
+    for c in thread_comments:
         c["content"] = sanitize_user_text(c["content"])
 
     # Top-level comments
-    song_comments = sorted([dict(c) for c in comments if c["replytoid"] is None], key=lambda c: c["created"])
+    song_comments = sorted([dict(c) for c in thread_comments if c["replytoid"] is None], key=lambda c: c["created"])
     song_comments = list(reversed(song_comments))
     # Replies (can only reply to top-level)
     for comment in song_comments:
-        comment["replies"] = sorted([c for c in comments if c["replytoid"] == comment["commentid"]], key=lambda c: c["created"])
+        comment["replies"] = sorted([c for c in thread_comments if c["replytoid"] == comment["commentid"]], key=lambda c: c["created"])
 
     return song_comments
 
@@ -1031,17 +914,17 @@ def get_gif_data():
 def get_current_user_playlists():
     plist_data = []
     if "userid" in session:
-        plist_data = query_db("select * from playlists where userid = ?", [session["userid"]])
+        plist_data = db.query("select * from playlists where userid = ?", [session["userid"]])
 
     return plist_data
 
 def get_user_colors(user_data):
     if isinstance(user_data, int):
         # Get colors for userid
-        user_data = query_db("select * from users where userid = ?", [user_data], one=True)
+        user_data = db.query("select * from users where userid = ?", [user_data], one=True)
     elif isinstance(user_data, str):
         # Get colors for username
-        user_data = query_db("select * from users where username = ?", [user_data], one=True)
+        user_data = db.query("select * from users where username = ?", [user_data], one=True)
 
     colors = dict(bgcolor=BGCOLOR, fgcolor=FGCOLOR, accolor=ACCOLOR)
     for key in colors:
@@ -1051,7 +934,7 @@ def get_user_colors(user_data):
     return colors
 
 def user_has_pfp(userid):
-    return (get_user_images_path(userid)/"pfp.jpg").exists()
+    return (datadir.get_user_images_path(userid)/"pfp.jpg").exists()
 
 @app.context_processor
 def inject_global_vars():
@@ -1062,62 +945,6 @@ def inject_global_vars():
     )
 
 
-################################################################################
-# Database
-################################################################################
-
-def get_db():
-    db = getattr(g, '_database', None)
-    if db is None:
-        db = g._database = sqlite3.connect(DATA_DIR / "database.db")
-        db.cursor().execute("PRAGMA foreign_keys = ON")
-        db.row_factory = sqlite3.Row
-
-        # Get current version
-        user_version = query_db("pragma user_version", one=True)[0]
-
-        # Run update script if DB is out of date
-        schema_update_script = SCRIPT_DIR / 'sql' / 'schema_update.sql'
-        if user_version < DB_VERSION and schema_update_script.exists():
-            with app.open_resource(schema_update_script, mode='r') as f:
-                db.cursor().executescript(f.read())
-            db.commit()
-    return db
-
-# TODO: Remove after deploying
-def assign_thread_ids(db, table, id_col, threadtype):
-    cur = db.execute(f"select * from {table}")
-    for row in cur:
-        thread_cur = db.execute("insert into comment_threads (threadtype, userid) values (?, ?) returning threadid", [threadtype, row["userid"]])
-        threadid = thread_cur.fetchone()[0]
-        thread_cur.close()
-
-        song_cur = db.execute(f"update {table} set threadid = ? where {id_col} = ?", [threadid, row[id_col]])
-        song_cur.close()
-    cur.close()
-
-@app.teardown_appcontext
-def close_db(exception):
-    db = getattr(g, '_database', None)
-    if db is not None:
-        db.close()
-
-def query_db(query, args=(), one=False):
-    cur = get_db().execute(query, args)
-    rv = cur.fetchall()
-    cur.close()
-    return (rv[0] if rv else None) if one else rv
-
-@app.cli.add_command
-@click.command("init-db")
-def init_db():
-    """Clear the existing data and create new tables"""
-    with app.app_context():
-        db = sqlite3.connect(DATA_DIR / "database.db")
-        with app.open_resource(SCRIPT_DIR / 'schema.sql', mode='r') as f:
-            db.cursor().executescript(f.read())
-        db.commit()
-
 ################################################################################
 # Generate Session Key
 ################################################################################
@@ -1132,11 +959,6 @@ def gen_key():
 class ObjectType(enum.IntEnum):
     COMMENT = 0
 
-class ThreadType(enum.IntEnum):
-    SONG = 0
-    PROFILE = 1
-    PLAYLIST = 2
-
 @dataclass
 class Song:
     songid: int
@@ -1220,7 +1042,7 @@ class Song:
 
     @classmethod
     def _from_db(cls, query, args=()):
-        songs_data = query_db(query, args)
+        songs_data = db.query(query, args)
         tags, collabs = cls._get_info_for_songs(songs_data)
         songs = []
         for sd in songs_data:
@@ -1237,8 +1059,8 @@ class Song:
         collabs = {}
         for song in songs:
             songid = song["songid"]
-            tags[songid] = query_db("select (tag) from song_tags where songid = ?", [songid])
-            collabs[songid] = query_db("select (name) from song_collaborators where songid = ?", [songid])
+            tags[songid] = db.query("select (tag) from song_tags where songid = ?", [songid])
+            collabs[songid] = db.query("select (name) from song_collaborators where songid = ?", [songid])
         return tags, collabs
 
 
diff --git a/src/littlesongplace/auth.py b/src/littlesongplace/auth.py
new file mode 100644 (file)
index 0000000..06bb85c
--- /dev/null
@@ -0,0 +1,96 @@
+from datetime import datetime, timezone
+
+import bcrypt
+from flask import Blueprint, render_template, redirect, flash, request, current_app, session
+
+from . import comments, db
+from .logutils import flash_and_log
+
+bp = Blueprint("auth", __name__)
+
+@bp.get("/signup")
+def signup_get():
+    return render_template("signup.html")
+
+@bp.post("/signup")
+def signup_post():
+    username = request.form["username"]
+    password = request.form["password"]
+    password_confirm = request.form["password_confirm"]
+
+    error = False
+    if not username.isidentifier():
+        flash_and_log("Username cannot contain special characters", "error")
+        error = True
+    elif len(username) < 3:
+        flash_and_log("Username must be at least 3 characters", "error")
+        error = True
+    elif len(username) > 30:
+        flash_and_log("Username cannot be more than 30 characters", "error")
+        error = True
+
+    elif password != password_confirm:
+        flash_and_log("Passwords do not match", "error")
+        error = True
+    elif len(password) < 8:
+        flash_and_log("Password must be at least 8 characters", "error")
+        error = True
+
+    if db.query("select * from users where username = ?", [username], one=True):
+        flash_and_log(f"Username '{username}' is already taken", "error")
+        error = True
+
+    if error:
+        current_app.logger.info("Failed signup attempt")
+        return redirect(request.referrer)
+
+    password = bcrypt.hashpw(password.encode(), bcrypt.gensalt())
+    timestamp = datetime.now(timezone.utc).isoformat()
+
+    user_data = db.query("insert into users (username, password, created) values (?, ?, ?) returning userid", [username, password, timestamp], one=True)
+
+    # Create profile comment thread
+    threadid = comments.create_thread(comments.ThreadType.PROFILE, user_data["userid"])
+    db.query("update users set threadid = ? where userid = ?", [threadid, user_data["userid"]])
+    db.commit()
+
+    flash("User created.  Please sign in to continue.", "success")
+    current_app.logger.info(f"Created user {username}")
+
+    return redirect("/login")
+
+@bp.get("/login")
+def login_get():
+    return render_template("login.html")
+
+@bp.post("/login")
+def login_post():
+    username = request.form["username"]
+    password = request.form["password"]
+
+    user_data = db.query("select * from users where username = ?", [username], one=True)
+
+    if user_data and bcrypt.checkpw(password.encode(), user_data["password"]):
+        # Successful login
+        session["username"] = username
+        session["userid"] = user_data["userid"]
+        session.permanent = True
+        current_app.logger.info(f"{username} logged in")
+
+        return redirect(f"/users/{username}")
+
+    flash("Invalid username/password", "error")
+    current_app.logger.info(f"Failed login for {username}")
+
+    return render_template("login.html")
+
+
+@bp.get("/logout")
+def logout():
+    if "username" in session:
+        session.pop("username")
+    if "userid" in session:
+        session.pop("userid")
+
+    return redirect("/")
+
diff --git a/src/littlesongplace/comments.py b/src/littlesongplace/comments.py
new file mode 100644 (file)
index 0000000..e516438
--- /dev/null
@@ -0,0 +1,14 @@
+import enum
+
+from . import db
+
+def create_thread(threadtype, userid):
+    thread = db.query("insert into comment_threads (threadtype, userid) values (?, ?) returning threadid", [threadtype, userid], one=True)
+    db.commit()
+    return thread["threadid"]
+
+class ThreadType(enum.IntEnum):
+    SONG = 0
+    PROFILE = 1
+    PLAYLIST = 2
+
diff --git a/src/littlesongplace/datadir.py b/src/littlesongplace/datadir.py
new file mode 100644 (file)
index 0000000..1d98c52
--- /dev/null
@@ -0,0 +1,30 @@
+import os
+from pathlib import Path
+
+_data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else Path(".data").absolute()
+
+# Make sure _data_dir exists
+os.makedirs(_data_dir, exist_ok=True)
+
+def get_db_path():
+    return _data_dir / "database.db"
+
+def set_data_dir(newdir):
+    global _data_dir
+    _data_dir = Path(newdir)
+
+def get_user_songs_path(userid):
+    userpath = _data_dir / "songs" / str(userid)
+    if not userpath.exists():
+        os.makedirs(userpath)
+    return userpath
+
+def get_user_images_path(userid):
+    userpath = _data_dir / "images" / str(userid)
+    if not userpath.exists():
+        os.makedirs(userpath)
+    return userpath
+
+def get_app_log_path():
+    return _data_dir / "app.log"
+
diff --git a/src/littlesongplace/db.py b/src/littlesongplace/db.py
new file mode 100644 (file)
index 0000000..bdf5efd
--- /dev/null
@@ -0,0 +1,55 @@
+import sqlite3
+from pathlib import Path
+
+import click
+from flask import g, current_app
+
+from . import datadir
+
+DB_VERSION = 4
+
+def get():
+    db = getattr(g, '_database', None)
+    if db is None:
+        db = g._database = sqlite3.connect(datadir.get_db_path())
+        db.cursor().execute("PRAGMA foreign_keys = ON")
+        db.row_factory = sqlite3.Row
+
+        # Get current version
+        user_version = query("pragma user_version", one=True)[0]
+
+        # Run update script if DB is out of date
+        schema_update_script = Path(current_app.root_path) / 'sql' / 'schema_update.sql'
+        if user_version < DB_VERSION and schema_update_script.exists():
+            with current_app.open_resource(schema_update_script, mode='r') as f:
+                db.cursor().executescript(f.read())
+            db.commit()
+    return db
+
+def close(exception):
+    db = getattr(g, '_database', None)
+    if db is not None:
+        db.close()
+
+def query(query, args=(), one=False):
+    cur = get().execute(query, args)
+    rv = cur.fetchall()
+    cur.close()
+    return (rv[0] if rv else None) if one else rv
+
+def commit():
+    get().commit()
+
+@click.command("init-db")
+def init_cmd():
+    """Clear the existing data and create new tables"""
+    with current_app.app_context():
+        db = sqlite3.connect(DATA_DIR / "database.db")
+        with app.open_resource(SCRIPT_DIR / 'schema.sql', mode='r') as f:
+            db.cursor().executescript(f.read())
+        db.commit()
+
+def init_app(app):
+    app.cli.add_command(init_cmd)
+    app.teardown_appcontext(close)
+
diff --git a/src/littlesongplace/logutils.py b/src/littlesongplace/logutils.py
new file mode 100644 (file)
index 0000000..32e5a43
--- /dev/null
@@ -0,0 +1,12 @@
+from flask import current_app, request, session, flash
+
+def flash_and_log(msg, category=None):
+    flash(msg, category)
+    username = session["username"] if "username" in session else "N/A"
+    url = request.referrer
+    logmsg = f"[{category}] User: {username}, URL: {url} - {msg}"
+    if category == "error":
+        current_app.logger.warning(logmsg)
+    else:
+        current_app.logger.info(logmsg)
+
index 3e03cf06bce213b979c2c2ef26c6ce81bf7a6752..ef395a6dbac0be48eb458a793121a49e3b6f0610 100644 (file)
@@ -14,11 +14,11 @@ from .utils import login
 def app():
     # Use temporary data directory
     with tempfile.TemporaryDirectory() as data_dir:
-        lsp.DATA_DIR = Path(data_dir)
+        lsp.datadir.set_data_dir(data_dir)
 
         # Initialize Database
         with lsp.app.app_context():
-            db = sqlite3.connect(lsp.DATA_DIR / "database.db")
+            db = sqlite3.connect(lsp.datadir.get_db_path())
             with lsp.app.open_resource('sql/schema.sql', mode='r') as f:
                 db.cursor().executescript(f.read())
             db.commit()