From 88fc729210c502a1eba6e76f5e95757781793183 Mon Sep 17 00:00:00 2001 From: Chris Fulljames Date: Sat, 5 Apr 2025 10:24:18 -0400 Subject: [PATCH] Start refactoring __init__ --- src/littlesongplace/__init__.py | 374 +++++++++----------------------- src/littlesongplace/auth.py | 96 ++++++++ src/littlesongplace/comments.py | 14 ++ src/littlesongplace/datadir.py | 30 +++ src/littlesongplace/db.py | 55 +++++ src/littlesongplace/logutils.py | 12 + test/conftest.py | 4 +- 7 files changed, 307 insertions(+), 278 deletions(-) create mode 100644 src/littlesongplace/auth.py create mode 100644 src/littlesongplace/comments.py create mode 100644 src/littlesongplace/datadir.py create mode 100644 src/littlesongplace/db.py create mode 100644 src/littlesongplace/logutils.py diff --git a/src/littlesongplace/__init__.py b/src/littlesongplace/__init__.py index 8e62746..5c3ff73 100644 --- a/src/littlesongplace/__init__.py +++ b/src/littlesongplace/__init__.py @@ -5,7 +5,6 @@ import logging import os import random import shutil -import sqlite3 import subprocess import sys import tempfile @@ -16,7 +15,6 @@ from logging.handlers import RotatingFileHandler from pathlib import Path, PosixPath from typing import Optional -import bcrypt import bleach import click from bleach.css_sanitizer import CSSSanitizer @@ -28,12 +26,8 @@ from werkzeug.middleware.proxy_fix import ProxyFix from yt_dlp import YoutubeDL from yt_dlp.utils import DownloadError -DB_VERSION = 4 -SCRIPT_DIR = Path(__file__).parent -DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else Path(".data").absolute() - -# Make sure DATA_DIR exists -os.makedirs(DATA_DIR, exist_ok=True) +from . import auth, comments, datadir, db +from .logutils import flash_and_log BGCOLOR = "#e8e6b5" FGCOLOR = "#695c73" @@ -44,7 +38,7 @@ DEFAULT_COLORS = dict(bgcolor=BGCOLOR, fgcolor=FGCOLOR, accolor=ACCOLOR) # Logging ################################################################################ -handler = RotatingFileHandler(DATA_DIR / "app.log", maxBytes=1_000_000, backupCount=10) +handler = RotatingFileHandler(datadir.get_app_log_path(), maxBytes=1_000_000, backupCount=10) handler.setLevel(logging.INFO) handler.setFormatter(logging.Formatter('[%(asctime)s] %(levelname)s in %(module)s: %(message)s')) @@ -58,6 +52,8 @@ root_logger.addHandler(handler) app = Flask(__name__) app.secret_key = os.environ["SECRET_KEY"] if "SECRET_KEY" in os.environ else "dev" app.config["MAX_CONTENT_LENGTH"] = 1 * 1024 * 1024 * 1024 +app.register_blueprint(auth.bp) +db.init_app(app) if "DATA_DIR" in os.environ: # Running on server behind proxy @@ -68,7 +64,7 @@ if "DATA_DIR" in os.environ: @app.route("/") def index(): - users = query_db("select * from users order by username asc") + users = db.query("select * from users order by username asc") users = [dict(row) for row in users] for user in users: user["has_pfp"] = user_has_pfp(user["userid"]) @@ -87,97 +83,11 @@ def index(): songs = Song.get_latest(50) return render_template("index.html", users=users, songs=songs, page_title=title) -@app.get("/signup") -def signup_get(): - return render_template("signup.html") - -@app.post("/signup") -def signup_post(): - username = request.form["username"] - password = request.form["password"] - password_confirm = request.form["password_confirm"] - - error = False - if not username.isidentifier(): - flash_and_log("Username cannot contain special characters", "error") - error = True - elif len(username) < 3: - flash_and_log("Username must be at least 3 characters", "error") - error = True - elif len(username) > 30: - flash_and_log("Username cannot be more than 30 characters", "error") - error = True - - elif password != password_confirm: - flash_and_log("Passwords do not match", "error") - error = True - elif len(password) < 8: - flash_and_log("Password must be at least 8 characters", "error") - error = True - - if query_db("select * from users where username = ?", [username], one=True): - flash_and_log(f"Username '{username}' is already taken", "error") - error = True - - if error: - app.logger.info("Failed signup attempt") - return redirect(request.referrer) - - password = bcrypt.hashpw(password.encode(), bcrypt.gensalt()) - timestamp = datetime.now(timezone.utc).isoformat() - - user_data = query_db("insert into users (username, password, created) values (?, ?, ?) returning userid", [username, password, timestamp], one=True) - - # Create profile comment thread - threadid = create_comment_thread(ThreadType.PROFILE, user_data["userid"]) - query_db("update users set threadid = ? where userid = ?", [threadid, user_data["userid"]]) - get_db().commit() - - flash("User created. Please sign in to continue.", "success") - app.logger.info(f"Created user {username}") - - return redirect("/login") - -@app.get("/login") -def login_get(): - return render_template("login.html") - -@app.post("/login") -def login_post(): - username = request.form["username"] - password = request.form["password"] - - user_data = query_db("select * from users where username = ?", [username], one=True) - - if user_data and bcrypt.checkpw(password.encode(), user_data["password"]): - # Successful login - session["username"] = username - session["userid"] = user_data["userid"] - session.permanent = True - app.logger.info(f"{username} logged in") - - return redirect(f"/users/{username}") - - flash("Invalid username/password", "error") - app.logger.info(f"Failed login for {username}") - - return render_template("login.html") - - -@app.get("/logout") -def logout(): - if "username" in session: - session.pop("username") - if "userid" in session: - session.pop("userid") - - return redirect("/") - @app.get("/users/") def users_profile(profile_username): # Look up user data for current profile - profile_data = query_db("select * from users where username = ?", [profile_username], one=True) + profile_data = db.query("select * from users where username = ?", [profile_username], one=True) if profile_data is None: abort(404) profile_userid = profile_data["userid"] @@ -186,15 +96,15 @@ def users_profile(profile_username): userid = session.get("userid", None) show_private = userid == profile_userid if show_private: - plist_data = query_db("select * from playlists where userid = ? order by updated desc", [profile_userid]) + plist_data = db.query("select * from playlists where userid = ? order by updated desc", [profile_userid]) else: - plist_data = query_db("select * from playlists where userid = ? and private = 0 order by updated desc", [profile_userid]) + plist_data = db.query("select * from playlists where userid = ? and private = 0 order by updated desc", [profile_userid]) # Get songs for current profile songs = Song.get_all_for_userid(profile_userid) # Get comments for current profile - comments = get_comments(profile_data["threadid"]) + profile_comments = get_comments(profile_data["threadid"]) # Sanitize bio profile_bio = "" @@ -209,7 +119,7 @@ def users_profile(profile_username): **get_user_colors(profile_data), playlists=plist_data, songs=songs, - comments=comments, + comments=profile_comments, threadid=profile_data["threadid"], user_has_pfp=user_has_pfp(profile_userid)) @@ -218,13 +128,13 @@ def edit_profile(): if not "userid" in session: abort(401) - query_db( + db.query( "update users set bio = ?, bgcolor = ?, fgcolor = ?, accolor = ? where userid = ?", [request.form["bio"], request.form["bgcolor"], request.form["fgcolor"], request.form["accolor"], session["userid"]]) - get_db().commit() + db.commit() if request.files["pfp"]: - pfp_path = get_user_images_path(session["userid"]) / "pfp.jpg" + pfp_path = datadir.get_user_images_path(session["userid"]) / "pfp.jpg" try: with Image.open(request.files["pfp"]) as im: @@ -263,7 +173,7 @@ def edit_profile(): @app.get("/pfp/") def pfp(userid): - return send_from_directory(get_user_images_path(userid), "pfp.jpg") + return send_from_directory(datadir.get_user_images_path(userid), "pfp.jpg") @app.get("/edit-song") def edit_song(): @@ -362,18 +272,6 @@ def validate_song_form(): return error -def get_user_songs_path(userid): - userpath = DATA_DIR / "songs" / str(userid) - if not userpath.exists(): - os.makedirs(userpath) - return userpath - -def get_user_images_path(userid): - userpath = DATA_DIR / "images" / str(userid) - if not userpath.exists(): - os.makedirs(userpath) - return userpath - def update_song(): songid = request.args["songid"] try: @@ -389,7 +287,7 @@ def update_song(): collaborators = [c.strip() for c in request.form["collabs"].split(",") if c] # Make sure song exists and the logged-in user owns it - song_data = query_db("select * from songs where songid = ?", [songid], one=True) + song_data = db.query("select * from songs where songid = ?", [songid], one=True) if song_data is None: abort(400) elif session["userid"] != song_data["userid"]: @@ -402,28 +300,28 @@ def update_song(): if passed: # Move file to permanent location - filepath = get_user_songs_path(session["userid"]) / (str(song_data["songid"]) + ".mp3") + filepath = datadir.get_user_songs_path(session["userid"]) / (str(song_data["songid"]) + ".mp3") shutil.move(tmp_file.name, filepath) else: error = True if not error: # Update songs table - query_db( + db.query( "update songs set title = ?, description = ? where songid = ?", [title, description, songid]) # Update song_tags table - query_db("delete from song_tags where songid = ?", [songid]) + db.query("delete from song_tags where songid = ?", [songid]) for tag in tags: - query_db("insert into song_tags (tag, songid) values (?, ?)", [tag, songid]) + db.query("insert into song_tags (tag, songid) values (?, ?)", [tag, songid]) # Update song_collaborators table - query_db("delete from song_collaborators where songid = ?", [songid]) + db.query("delete from song_collaborators where songid = ?", [songid]) for collab in collaborators: - query_db("insert into song_collaborators (name, songid) values (?, ?)", [collab, songid]) + db.query("insert into song_collaborators (name, songid) values (?, ?)", [collab, songid]) - get_db().commit() + db.commit() flash_and_log(f"Successfully updated '{title}'", "success") return error @@ -443,27 +341,27 @@ def create_song(): return True else: # Create comment thread - threadid = create_comment_thread(ThreadType.SONG, session["userid"]) + threadid = comments.create_thread(comments.ThreadType.SONG, session["userid"]) # Create song timestamp = datetime.now(timezone.utc).isoformat() - song_data = query_db( + song_data = db.query( "insert into songs (userid, title, description, created, threadid) values (?, ?, ?, ?, ?) returning (songid)", [session["userid"], title, description, timestamp, threadid], one=True) songid = song_data["songid"] - filepath = get_user_songs_path(session["userid"]) / (str(song_data["songid"]) + ".mp3") + filepath = datadir.get_user_songs_path(session["userid"]) / (str(song_data["songid"]) + ".mp3") # Move file to permanent location shutil.move(tmp_file.name, filepath) # Assign tags for tag in tags: - query_db("insert into song_tags (tag, songid) values (?, ?)", [tag, songid]) + db.query("insert into song_tags (tag, songid) values (?, ?)", [tag, songid]) # Assign collaborators for collab in collaborators: - query_db("insert into song_collaborators (songid, name) values (?, ?)", [songid, collab]) + db.query("insert into song_collaborators (songid, name) values (?, ?)", [songid, collab]) - get_db().commit() + db.commit() flash_and_log(f"Successfully uploaded '{title}'", "success") return False @@ -521,7 +419,7 @@ def yt_import(tmp_file, yt_url): @app.get("/delete-song/") def delete_song(songid): - song_data = query_db("select * from songs where songid = ?", [songid], one=True) + song_data = db.query("select * from songs where songid = ?", [songid], one=True) if not song_data: app.logger.warning(f"Failed song delete - {session['username']} - song doesn't exist") @@ -533,15 +431,15 @@ def delete_song(songid): abort(401) # Delete tags, collaborators - query_db("delete from song_tags where songid = ?", [songid]) - query_db("delete from song_collaborators where songid = ?", [songid]) + db.query("delete from song_tags where songid = ?", [songid]) + db.query("delete from song_collaborators where songid = ?", [songid]) # Delete song database entry - query_db("delete from songs where songid = ?", [songid]) - get_db().commit() + db.query("delete from songs where songid = ?", [songid]) + db.commit() # Delete song file from disk - songpath = DATA_DIR / "songs" / str(session["userid"]) / (str(songid) + ".mp3") + songpath = datadir.get_user_songs_path(session["userid"]) / (str(songid) + ".mp3") if songpath.exists(): os.remove(songpath) @@ -566,7 +464,7 @@ def song(userid, songid): except ValueError: abort(404) else: - return send_from_directory(DATA_DIR / "songs" / str(userid), str(songid) + ".mp3") + return send_from_directory(datadir.get_user_songs_path(userid), str(songid) + ".mp3") @app.get("/songs") def songs(): @@ -596,7 +494,7 @@ def comment(): if not "threadid" in request.args: abort(400) # Must have threadid - thread = query_db("select * from comment_threads where threadid = ?", [request.args["threadid"]], one=True) + thread = db.query("select * from comment_threads where threadid = ?", [request.args["threadid"]], one=True) if not thread: abort(404) # Invalid threadid @@ -604,7 +502,7 @@ def comment(): replyto = None if "replytoid" in request.args: replytoid = request.args["replytoid"] - replyto = query_db("select * from comments inner join users on comments.userid == users.userid where commentid = ?", [replytoid], one=True) + replyto = db.query("select * from comments inner join users on comments.userid == users.userid where commentid = ?", [replytoid], one=True) if not replyto: abort(404) # Invalid comment @@ -612,7 +510,7 @@ def comment(): comment = None if "commentid" in request.args: commentid = request.args["commentid"] - comment = query_db("select * from comments inner join users on comments.userid == users.userid where commentid = ?", [commentid], one=True) + comment = db.query("select * from comments inner join users on comments.userid == users.userid where commentid = ?", [commentid], one=True) if not comment: abort(404) # Invalid comment if comment["userid"] != session["userid"]: @@ -625,12 +523,12 @@ def comment(): song = None profile = None playlist = None - if threadtype == ThreadType.SONG: + if threadtype == comments.ThreadType.SONG: song = Song.by_threadid(request.args["threadid"]) - elif threadtype == ThreadType.PROFILE: - profile = query_db("select * from users where threadid = ?", [request.args["threadid"]], one=True) - elif threadtype == ThreadType.PLAYLIST: - profile = query_db("select * from playlists inner join users on playlists.userid = users.userid where playlists.threadid = ?", [request.args["threadid"]], one=True) + elif threadtype == comments.ThreadType.PROFILE: + profile = db.query("select * from users where threadid = ?", [request.args["threadid"]], one=True) + elif threadtype == comments.ThreadType.PLAYLIST: + profile = db.query("select * from playlists inner join users on playlists.userid = users.userid where playlists.threadid = ?", [request.args["threadid"]], one=True) return render_template( "comment.html", song=song, @@ -645,7 +543,7 @@ def comment(): content = request.form["content"] if comment: # Update existing comment - query_db("update comments set content = ? where commentid = ?", args=[content, comment["commentid"]]) + db.query("update comments set content = ? where commentid = ?", args=[content, comment["commentid"]]) else: # Add new comment timestamp = datetime.now(timezone.utc).isoformat() @@ -653,7 +551,7 @@ def comment(): replytoid = request.args.get("replytoid", None) threadid = request.args["threadid"] - comment = query_db( + comment = db.query( "insert into comments (threadid, userid, replytoid, created, content) values (?, ?, ?, ?, ?) returning (commentid)", args=[threadid, userid, replytoid, timestamp, content], one=True) commentid = comment["commentid"] @@ -665,7 +563,7 @@ def comment(): notification_targets.add(replyto["userid"]) # Notify previous repliers in thread - previous_replies = query_db("select * from comments where replytoid = ?", [replytoid]) + previous_replies = db.query("select * from comments where replytoid = ?", [replytoid]) for reply in previous_replies: notification_targets.add(reply["userid"]) @@ -675,9 +573,9 @@ def comment(): # Create notifications for target in notification_targets: - query_db("insert into notifications (objectid, objecttype, targetuserid, created) values (?, ?, ?, ?)", [commentid, ObjectType.COMMENT, target, timestamp]) + db.query("insert into notifications (objectid, objecttype, targetuserid, created) values (?, ?, ?, ?)", [commentid, ObjectType.COMMENT, target, timestamp]) - get_db().commit() + db.commit() return redirect_to_previous_page() @@ -693,7 +591,7 @@ def comment_delete(commentid): if "userid" not in session: return redirect("/login") - comment = query_db("select c.userid as comment_user, t.userid as thread_user from comments as c inner join comment_threads as t on c.threadid == t.threadid where commentid = ?", [commentid], one=True) + comment = db.query("select c.userid as comment_user, t.userid as thread_user from comments as c inner join comment_threads as t on c.threadid == t.threadid where commentid = ?", [commentid], one=True) if not comment: abort(404) # Invalid comment @@ -702,8 +600,8 @@ def comment_delete(commentid): or (comment["thread_user"] == session["userid"])): abort(403) - query_db("delete from comments where (commentid = ?) or (replytoid = ?)", [commentid, commentid]) - get_db().commit() + db.query("delete from comments where (commentid = ?) or (replytoid = ?)", [commentid, commentid]) + db.commit() return redirect(request.referrer) @@ -713,7 +611,7 @@ def activity(): return redirect("/login") # Get comment notifications - comments = query_db( + notifications = db.query( """\ select c.content, c.commentid, c.replytoid, cu.username as comment_username, rc.content as replyto_content, c.threadid, t.threadtype from notifications as n @@ -726,21 +624,21 @@ def activity(): """, [session["userid"], ObjectType.COMMENT]) - comments = [dict(c) for c in comments] - for comment in comments: + notifications = [dict(c) for c in notifications] + for comment in notifications: threadtype = comment["threadtype"] - if threadtype == ThreadType.SONG: + if threadtype == comments.ThreadType.SONG: song = Song.by_threadid(comment["threadid"]) comment["songid"] = song.songid comment["title"] = song.title comment["content_userid"] = song.userid comment["content_username"] = song.username - elif threadtype == ThreadType.PROFILE: - profile = query_db("select * from users where threadid = ?", [comment["threadid"]], one=True) + elif threadtype == comments.ThreadType.PROFILE: + profile = db.query("select * from users where threadid = ?", [comment["threadid"]], one=True) comment["content_userid"] = profile["userid"] comment["content_username"] = profile["username"] - elif threadtype == ThreadType.PLAYLIST: - playlist = query_db( + elif threadtype == comments.ThreadType.PLAYLIST: + playlist = db.query( """\ select * from playlists inner join users on playlists.userid == users.userid @@ -755,17 +653,17 @@ def activity(): comment["content_username"] = playlist["username"] timestamp = datetime.now(timezone.utc).isoformat() - query_db("update users set activitytime = ? where userid = ?", [timestamp, session["userid"]]) - get_db().commit() + db.query("update users set activitytime = ? where userid = ?", [timestamp, session["userid"]]) + db.commit() - return render_template("activity.html", comments=comments) + return render_template("activity.html", comments=notifications) @app.get("/new-activity") def new_activity(): has_new_activity = False if "userid" in session: - user_data = query_db("select activitytime from users where userid = ?", [session["userid"]], one=True) - comment_data = query_db( + user_data = db.query("select activitytime from users where userid = ?", [session["userid"]], one=True) + comment_data = db.query( """\ select created from notifications where targetuserid = ? @@ -801,9 +699,9 @@ def create_playlist(): private = request.form["type"] == "private" - threadid = create_comment_thread(ThreadType.PLAYLIST, session["userid"]) + threadid = comments.create_thread(comments.ThreadType.PLAYLIST, session["userid"]) - query_db( + db.query( "insert into playlists (created, updated, userid, name, private, threadid) values (?, ?, ?, ?, ?, ?)", args=[ timestamp, @@ -814,7 +712,7 @@ def create_playlist(): threadid ] ) - get_db().commit() + db.commit() flash_and_log(f"Created playlist {name}", "success") return redirect(request.referrer) @@ -824,7 +722,7 @@ def delete_playlist(playlistid): abort(401) # Make sure playlist exists - plist_data = query_db("select * from playlists where playlistid = ?", args=[playlistid], one=True) + plist_data = db.query("select * from playlists where playlistid = ?", args=[playlistid], one=True) if not plist_data: abort(404) @@ -833,8 +731,8 @@ def delete_playlist(playlistid): abort(403) # Delete playlist - query_db("delete from playlists where playlistid = ?", args=[playlistid]) - get_db().commit() + db.query("delete from playlists where playlistid = ?", args=[playlistid]) + db.commit() flash_and_log(f"Deleted playlist {plist_data['name']}", "success") return redirect(f"/users/{session['username']}") @@ -850,7 +748,7 @@ def append_to_playlist(): except ValueError: abort(400) - plist_data = query_db("select * from playlists where playlistid = ?", args=[playlistid], one=True) + plist_data = db.query("select * from playlists where playlistid = ?", args=[playlistid], one=True) if not plist_data: abort(404) @@ -861,21 +759,21 @@ def append_to_playlist(): songid = request.form["songid"] # Make sure song exists - song_data = query_db("select * from songs where songid = ?", args=[songid], one=True) + song_data = db.query("select * from songs where songid = ?", args=[songid], one=True) if not song_data: abort(404) # Set index to count of songs in list - existing_songs = query_db("select * from playlist_songs where playlistid = ?", args=[playlistid]) + existing_songs = db.query("select * from playlist_songs where playlistid = ?", args=[playlistid]) new_position = len(existing_songs) # Add to playlist - query_db("insert into playlist_songs (playlistid, position, songid) values (?, ?, ?)", args=[playlistid, new_position, songid]) + db.query("insert into playlist_songs (playlistid, position, songid) values (?, ?, ?)", args=[playlistid, new_position, songid]) # Update modification time timestamp = datetime.now(timezone.utc).isoformat() - query_db("update playlists set updated = ? where playlistid = ?", args=[timestamp, playlistid]) - get_db().commit() + db.query("update playlists set updated = ? where playlistid = ?", args=[timestamp, playlistid]) + db.commit() flash_and_log(f"Added '{song_data['title']}' to {plist_data['name']}", "success") @@ -887,7 +785,7 @@ def edit_playlist_post(playlistid): abort(401) # Make sure playlist exists - plist_data = query_db("select * from playlists where playlistid = ?", args=[playlistid], one=True) + plist_data = db.query("select * from playlists where playlistid = ?", args=[playlistid], one=True) if not plist_data: abort(404) @@ -911,23 +809,23 @@ def edit_playlist_post(playlistid): abort(400) for songid in songids: - song_data = query_db("select * from songs where songid = ?", args=[songid]) + song_data = db.query("select * from songs where songid = ?", args=[songid]) if not song_data: abort(400) # All songs valid - delete old songs - query_db("delete from playlist_songs where playlistid = ?", args=[playlistid]) + db.query("delete from playlist_songs where playlistid = ?", args=[playlistid]) # Re-add songs with new positions for position, songid in enumerate(songids): print(position, songid) - query_db("insert into playlist_songs (playlistid, position, songid) values (?, ?, ?)", args=[playlistid, position, songid]) + db.query("insert into playlist_songs (playlistid, position, songid) values (?, ?, ?)", args=[playlistid, position, songid]) # Update private, name private = int(request.form["type"] == "private") - query_db("update playlists set private = ?, name = ? where playlistid = ?", [private, name, playlistid]) + db.query("update playlists set private = ?, name = ? where playlistid = ?", [private, name, playlistid]) - get_db().commit() + db.commit() flash_and_log("Playlist updated", "success") return redirect(request.referrer) @@ -936,7 +834,7 @@ def edit_playlist_post(playlistid): def playlists(playlistid): # Make sure playlist exists - plist_data = query_db("select * from playlists inner join users on playlists.userid = users.userid where playlistid = ?", args=[playlistid], one=True) + plist_data = db.query("select * from playlists inner join users on playlists.userid = users.userid where playlistid = ?", args=[playlistid], one=True) if not plist_data: abort(404) @@ -949,7 +847,7 @@ def playlists(playlistid): songs = Song.get_for_playlist(playlistid) # Get comments - comments = get_comments(plist_data["threadid"]) + plist_comments = get_comments(plist_data["threadid"]) # Show page return render_template( @@ -962,17 +860,7 @@ def playlists(playlistid): threadid=plist_data["threadid"], **get_user_colors(plist_data), songs=songs, - comments=comments) - -def flash_and_log(msg, category=None): - flash(msg, category) - username = session["username"] if "username" in session else "N/A" - url = request.referrer - logmsg = f"[{category}] User: {username}, URL: {url} - {msg}" - if category == "error": - app.logger.warning(logmsg) - else: - app.logger.info(logmsg) + comments=plist_comments) def sanitize_user_text(text): allowed_tags = bleach.sanitizer.ALLOWED_TAGS.union({ @@ -996,23 +884,18 @@ def sanitize_user_text(text): attributes=allowed_attributes, css_sanitizer=css_sanitizer) -def create_comment_thread(threadtype, userid): - thread = query_db("insert into comment_threads (threadtype, userid) values (?, ?) returning threadid", [threadtype, userid], one=True) - get_db().commit() - return thread["threadid"] - def get_comments(threadid): - comments = query_db("select * from comments inner join users on comments.userid == users.userid where comments.threadid = ?", [threadid]) - comments = [dict(c) for c in comments] - for c in comments: + thread_comments = db.query("select * from comments inner join users on comments.userid == users.userid where comments.threadid = ?", [threadid]) + thread_comments = [dict(c) for c in thread_comments] + for c in thread_comments: c["content"] = sanitize_user_text(c["content"]) # Top-level comments - song_comments = sorted([dict(c) for c in comments if c["replytoid"] is None], key=lambda c: c["created"]) + song_comments = sorted([dict(c) for c in thread_comments if c["replytoid"] is None], key=lambda c: c["created"]) song_comments = list(reversed(song_comments)) # Replies (can only reply to top-level) for comment in song_comments: - comment["replies"] = sorted([c for c in comments if c["replytoid"] == comment["commentid"]], key=lambda c: c["created"]) + comment["replies"] = sorted([c for c in thread_comments if c["replytoid"] == comment["commentid"]], key=lambda c: c["created"]) return song_comments @@ -1031,17 +914,17 @@ def get_gif_data(): def get_current_user_playlists(): plist_data = [] if "userid" in session: - plist_data = query_db("select * from playlists where userid = ?", [session["userid"]]) + plist_data = db.query("select * from playlists where userid = ?", [session["userid"]]) return plist_data def get_user_colors(user_data): if isinstance(user_data, int): # Get colors for userid - user_data = query_db("select * from users where userid = ?", [user_data], one=True) + user_data = db.query("select * from users where userid = ?", [user_data], one=True) elif isinstance(user_data, str): # Get colors for username - user_data = query_db("select * from users where username = ?", [user_data], one=True) + user_data = db.query("select * from users where username = ?", [user_data], one=True) colors = dict(bgcolor=BGCOLOR, fgcolor=FGCOLOR, accolor=ACCOLOR) for key in colors: @@ -1051,7 +934,7 @@ def get_user_colors(user_data): return colors def user_has_pfp(userid): - return (get_user_images_path(userid)/"pfp.jpg").exists() + return (datadir.get_user_images_path(userid)/"pfp.jpg").exists() @app.context_processor def inject_global_vars(): @@ -1062,62 +945,6 @@ def inject_global_vars(): ) -################################################################################ -# Database -################################################################################ - -def get_db(): - db = getattr(g, '_database', None) - if db is None: - db = g._database = sqlite3.connect(DATA_DIR / "database.db") - db.cursor().execute("PRAGMA foreign_keys = ON") - db.row_factory = sqlite3.Row - - # Get current version - user_version = query_db("pragma user_version", one=True)[0] - - # Run update script if DB is out of date - schema_update_script = SCRIPT_DIR / 'sql' / 'schema_update.sql' - if user_version < DB_VERSION and schema_update_script.exists(): - with app.open_resource(schema_update_script, mode='r') as f: - db.cursor().executescript(f.read()) - db.commit() - return db - -# TODO: Remove after deploying -def assign_thread_ids(db, table, id_col, threadtype): - cur = db.execute(f"select * from {table}") - for row in cur: - thread_cur = db.execute("insert into comment_threads (threadtype, userid) values (?, ?) returning threadid", [threadtype, row["userid"]]) - threadid = thread_cur.fetchone()[0] - thread_cur.close() - - song_cur = db.execute(f"update {table} set threadid = ? where {id_col} = ?", [threadid, row[id_col]]) - song_cur.close() - cur.close() - -@app.teardown_appcontext -def close_db(exception): - db = getattr(g, '_database', None) - if db is not None: - db.close() - -def query_db(query, args=(), one=False): - cur = get_db().execute(query, args) - rv = cur.fetchall() - cur.close() - return (rv[0] if rv else None) if one else rv - -@app.cli.add_command -@click.command("init-db") -def init_db(): - """Clear the existing data and create new tables""" - with app.app_context(): - db = sqlite3.connect(DATA_DIR / "database.db") - with app.open_resource(SCRIPT_DIR / 'schema.sql', mode='r') as f: - db.cursor().executescript(f.read()) - db.commit() - ################################################################################ # Generate Session Key ################################################################################ @@ -1132,11 +959,6 @@ def gen_key(): class ObjectType(enum.IntEnum): COMMENT = 0 -class ThreadType(enum.IntEnum): - SONG = 0 - PROFILE = 1 - PLAYLIST = 2 - @dataclass class Song: songid: int @@ -1220,7 +1042,7 @@ class Song: @classmethod def _from_db(cls, query, args=()): - songs_data = query_db(query, args) + songs_data = db.query(query, args) tags, collabs = cls._get_info_for_songs(songs_data) songs = [] for sd in songs_data: @@ -1237,8 +1059,8 @@ class Song: collabs = {} for song in songs: songid = song["songid"] - tags[songid] = query_db("select (tag) from song_tags where songid = ?", [songid]) - collabs[songid] = query_db("select (name) from song_collaborators where songid = ?", [songid]) + tags[songid] = db.query("select (tag) from song_tags where songid = ?", [songid]) + collabs[songid] = db.query("select (name) from song_collaborators where songid = ?", [songid]) return tags, collabs diff --git a/src/littlesongplace/auth.py b/src/littlesongplace/auth.py new file mode 100644 index 0000000..06bb85c --- /dev/null +++ b/src/littlesongplace/auth.py @@ -0,0 +1,96 @@ +from datetime import datetime, timezone + +import bcrypt +from flask import Blueprint, render_template, redirect, flash, request, current_app, session + +from . import comments, db +from .logutils import flash_and_log + +bp = Blueprint("auth", __name__) + +@bp.get("/signup") +def signup_get(): + return render_template("signup.html") + +@bp.post("/signup") +def signup_post(): + username = request.form["username"] + password = request.form["password"] + password_confirm = request.form["password_confirm"] + + error = False + if not username.isidentifier(): + flash_and_log("Username cannot contain special characters", "error") + error = True + elif len(username) < 3: + flash_and_log("Username must be at least 3 characters", "error") + error = True + elif len(username) > 30: + flash_and_log("Username cannot be more than 30 characters", "error") + error = True + + elif password != password_confirm: + flash_and_log("Passwords do not match", "error") + error = True + elif len(password) < 8: + flash_and_log("Password must be at least 8 characters", "error") + error = True + + if db.query("select * from users where username = ?", [username], one=True): + flash_and_log(f"Username '{username}' is already taken", "error") + error = True + + if error: + current_app.logger.info("Failed signup attempt") + return redirect(request.referrer) + + password = bcrypt.hashpw(password.encode(), bcrypt.gensalt()) + timestamp = datetime.now(timezone.utc).isoformat() + + user_data = db.query("insert into users (username, password, created) values (?, ?, ?) returning userid", [username, password, timestamp], one=True) + + # Create profile comment thread + threadid = comments.create_thread(comments.ThreadType.PROFILE, user_data["userid"]) + db.query("update users set threadid = ? where userid = ?", [threadid, user_data["userid"]]) + db.commit() + + flash("User created. Please sign in to continue.", "success") + current_app.logger.info(f"Created user {username}") + + return redirect("/login") + +@bp.get("/login") +def login_get(): + return render_template("login.html") + +@bp.post("/login") +def login_post(): + username = request.form["username"] + password = request.form["password"] + + user_data = db.query("select * from users where username = ?", [username], one=True) + + if user_data and bcrypt.checkpw(password.encode(), user_data["password"]): + # Successful login + session["username"] = username + session["userid"] = user_data["userid"] + session.permanent = True + current_app.logger.info(f"{username} logged in") + + return redirect(f"/users/{username}") + + flash("Invalid username/password", "error") + current_app.logger.info(f"Failed login for {username}") + + return render_template("login.html") + + +@bp.get("/logout") +def logout(): + if "username" in session: + session.pop("username") + if "userid" in session: + session.pop("userid") + + return redirect("/") + diff --git a/src/littlesongplace/comments.py b/src/littlesongplace/comments.py new file mode 100644 index 0000000..e516438 --- /dev/null +++ b/src/littlesongplace/comments.py @@ -0,0 +1,14 @@ +import enum + +from . import db + +def create_thread(threadtype, userid): + thread = db.query("insert into comment_threads (threadtype, userid) values (?, ?) returning threadid", [threadtype, userid], one=True) + db.commit() + return thread["threadid"] + +class ThreadType(enum.IntEnum): + SONG = 0 + PROFILE = 1 + PLAYLIST = 2 + diff --git a/src/littlesongplace/datadir.py b/src/littlesongplace/datadir.py new file mode 100644 index 0000000..1d98c52 --- /dev/null +++ b/src/littlesongplace/datadir.py @@ -0,0 +1,30 @@ +import os +from pathlib import Path + +_data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else Path(".data").absolute() + +# Make sure _data_dir exists +os.makedirs(_data_dir, exist_ok=True) + +def get_db_path(): + return _data_dir / "database.db" + +def set_data_dir(newdir): + global _data_dir + _data_dir = Path(newdir) + +def get_user_songs_path(userid): + userpath = _data_dir / "songs" / str(userid) + if not userpath.exists(): + os.makedirs(userpath) + return userpath + +def get_user_images_path(userid): + userpath = _data_dir / "images" / str(userid) + if not userpath.exists(): + os.makedirs(userpath) + return userpath + +def get_app_log_path(): + return _data_dir / "app.log" + diff --git a/src/littlesongplace/db.py b/src/littlesongplace/db.py new file mode 100644 index 0000000..bdf5efd --- /dev/null +++ b/src/littlesongplace/db.py @@ -0,0 +1,55 @@ +import sqlite3 +from pathlib import Path + +import click +from flask import g, current_app + +from . import datadir + +DB_VERSION = 4 + +def get(): + db = getattr(g, '_database', None) + if db is None: + db = g._database = sqlite3.connect(datadir.get_db_path()) + db.cursor().execute("PRAGMA foreign_keys = ON") + db.row_factory = sqlite3.Row + + # Get current version + user_version = query("pragma user_version", one=True)[0] + + # Run update script if DB is out of date + schema_update_script = Path(current_app.root_path) / 'sql' / 'schema_update.sql' + if user_version < DB_VERSION and schema_update_script.exists(): + with current_app.open_resource(schema_update_script, mode='r') as f: + db.cursor().executescript(f.read()) + db.commit() + return db + +def close(exception): + db = getattr(g, '_database', None) + if db is not None: + db.close() + +def query(query, args=(), one=False): + cur = get().execute(query, args) + rv = cur.fetchall() + cur.close() + return (rv[0] if rv else None) if one else rv + +def commit(): + get().commit() + +@click.command("init-db") +def init_cmd(): + """Clear the existing data and create new tables""" + with current_app.app_context(): + db = sqlite3.connect(DATA_DIR / "database.db") + with app.open_resource(SCRIPT_DIR / 'schema.sql', mode='r') as f: + db.cursor().executescript(f.read()) + db.commit() + +def init_app(app): + app.cli.add_command(init_cmd) + app.teardown_appcontext(close) + diff --git a/src/littlesongplace/logutils.py b/src/littlesongplace/logutils.py new file mode 100644 index 0000000..32e5a43 --- /dev/null +++ b/src/littlesongplace/logutils.py @@ -0,0 +1,12 @@ +from flask import current_app, request, session, flash + +def flash_and_log(msg, category=None): + flash(msg, category) + username = session["username"] if "username" in session else "N/A" + url = request.referrer + logmsg = f"[{category}] User: {username}, URL: {url} - {msg}" + if category == "error": + current_app.logger.warning(logmsg) + else: + current_app.logger.info(logmsg) + diff --git a/test/conftest.py b/test/conftest.py index 3e03cf0..ef395a6 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -14,11 +14,11 @@ from .utils import login def app(): # Use temporary data directory with tempfile.TemporaryDirectory() as data_dir: - lsp.DATA_DIR = Path(data_dir) + lsp.datadir.set_data_dir(data_dir) # Initialize Database with lsp.app.app_context(): - db = sqlite3.connect(lsp.DATA_DIR / "database.db") + db = sqlite3.connect(lsp.datadir.get_db_path()) with lsp.app.open_resource('sql/schema.sql', mode='r') as f: db.cursor().executescript(f.read()) db.commit() -- 2.39.5