diff --git a/backend/api.go b/backend/api.go index d331750..032512b 100644 --- a/backend/api.go +++ b/backend/api.go @@ -2,6 +2,11 @@ package main import ( "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "io" + "log" "net/http" ) @@ -13,6 +18,82 @@ type ApiHandler struct { const authTokenCookieName = "auth-token" const isAuthorizedContextKey = "is-authorized" +func (h *ApiHandler) ServeLoginPost(writer http.ResponseWriter, request *http.Request) { + bodyReader := request.Body + body, err := io.ReadAll(bodyReader) + _ = bodyReader.Close() + if err != nil { + http.Error(writer, err.Error(), http.StatusBadRequest) + return + } + + type LoginBody struct { + Password string `json:"password"` + } + + loginBody := LoginBody{} + err = json.Unmarshal(body, &loginBody) + if err != nil { + http.Error(writer, err.Error(), http.StatusBadRequest) + return + } + + success, err := h.db.ValidateRootPassword(loginBody.Password) + if err != nil { + log.Println("Error logging in:", err) + http.Error(writer, "failed to read database", http.StatusInternalServerError) + return + } + + if !success { + http.Error(writer, "invalid password", http.StatusUnauthorized) + return + } + + rawAuthToken := make([]byte, 128) + _, _ = rand.Read(rawAuthToken) + authToken := hex.EncodeToString(rawAuthToken) + h.authToken = &authToken + + cookie := http.Cookie{} + cookie.Name = authTokenCookieName + cookie.Value = authToken + cookie.Secure = true + cookie.HttpOnly = true + http.SetCookie(writer, &cookie) + + writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(http.StatusOK) + err = json.NewEncoder(writer).Encode(map[string]interface{}{}) + if err != nil { + http.Error(writer, "failed to serialize results", http.StatusInternalServerError) + return + } + +} + +func (h *ApiHandler) ServeLogoutPost(writer http.ResponseWriter, request *http.Request) { + cookie, _ := request.Cookie(authTokenCookieName) + if cookie != nil { + cookie := http.Cookie{} + cookie.Name = authTokenCookieName + cookie.Value = "" + cookie.Secure = true + cookie.HttpOnly = true + http.SetCookie(writer, &cookie) + } + + h.authToken = nil + + writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(http.StatusOK) + err := json.NewEncoder(writer).Encode(map[string]interface{}{}) + if err != nil { + http.Error(writer, "failed to serialize results", http.StatusInternalServerError) + return + } +} + func (h *ApiHandler) ProcessAuth(next http.Handler, required bool) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { isAuthorized := false diff --git a/backend/db.go b/backend/db.go index 59e5b04..b1ea2df 100644 --- a/backend/db.go +++ b/backend/db.go @@ -1,9 +1,12 @@ package main import ( + "crypto/rand" + "crypto/sha512" "database/sql" "errors" "log" + "slices" "strconv" "time" @@ -14,7 +17,7 @@ type Database struct { db *sql.DB } -func NewDatabase(path string) (*Database, error) { +func NewDatabase(path string, rootPassword string) (*Database, error) { log.Println("Opening database '" + path + "'") db, err := sql.Open("sqlite3", path) @@ -22,7 +25,7 @@ func NewDatabase(path string) (*Database, error) { return nil, err } - err = migrate(db) + err = migrate(db, rootPassword) if err != nil { _ = db.Close() return nil, err @@ -38,6 +41,18 @@ func (db *Database) Close() { } } +func (db *Database) ValidateRootPassword(password string) (bool, error) { + var salt []byte + var rootPasswordHash []byte + err := db.db.QueryRow("SELECT salt, root_password FROM schema_info").Scan(&salt, &rootPasswordHash) + if err != nil { + return false, err + } + + passwordHash := sha512.Sum512(append(salt, []byte(password)...)) + return slices.Compare(passwordHash[:], rootPasswordHash) == 0, nil +} + func (db *Database) CreateBlogArticle() (int64, func(*Article) error, error) { tx, err := db.db.Begin() if err != nil { @@ -188,10 +203,10 @@ func (db *Database) GetBlogArticles(showAll bool, offset int, limit int) ([]Arti return articles, nil } -func migrate(db *sql.DB) error { +func migrate(db *sql.DB, rootPassword string) error { tx, err := db.Begin() if err != nil { - log.Fatal(err) + return err } log.Println("Checking database schema version") @@ -204,7 +219,7 @@ func migrate(db *sql.DB) error { if curVersion == -1 { log.Println("Database is empty") - err = createV1Tables(tx) + err = createV1Tables(tx, rootPassword) if err != nil { _ = tx.Rollback() return err @@ -213,27 +228,41 @@ func migrate(db *sql.DB) error { log.Println("Database schema version is", curVersion) if curVersion != 1 { - log.Fatalln("Unsupported database schema version") + return errors.New("unsupported database schema version") } } return tx.Commit() } -func createV1Tables(tx *sql.Tx) error { +func createV1Tables(tx *sql.Tx, rootPassword string) error { log.Println("Creating tables for schema version 1") + salt := make([]byte, 32) + _, _ = rand.Read(salt) + + if rootPassword == "" { + return errors.New("root password is required") + } + rootPasswordHash := sha512.Sum512(append(salt, []byte(rootPassword)...)) + _, err := tx.Exec(` CREATE TABLE schema_info( id INTEGER PRIMARY KEY CHECK(id = 0), - version INTEGER NOT NULL + version INTEGER NOT NULL, + salt BLOB NOT NULL, + root_password BLOB NOT NULL ); INSERT INTO schema_info( id, - version + version, + salt, + root_password ) VALUES ( 0, - 1 + 1, + ?, + ? ); CREATE TABLE blog_tag( @@ -260,7 +289,7 @@ func createV1Tables(tx *sql.Tx) error { PRIMARY KEY(id, articleId), FOREIGN KEY(articleId) REFERENCES blog_article(id) ); - `) + `, salt, rootPasswordHash[:]) if err != nil { return err } diff --git a/backend/main.go b/backend/main.go index d6d620d..37aa649 100644 --- a/backend/main.go +++ b/backend/main.go @@ -22,7 +22,7 @@ func main() { dbPath = "./database.sqlite" } - db, err := NewDatabase(dbPath) + db, err := NewDatabase(dbPath, os.Getenv("ROOT_PASSWORD")) if err != nil { log.Fatal(err.Error()) } @@ -44,6 +44,12 @@ func main() { mux.Handle("GET /api/blog/{articleId}/file/{fileId}", apiHandler.ProcessAuth(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { apiHandler.ServeBlogFileGetSingle(writer, request) }), false)) + mux.HandleFunc("POST /api/login", func(writer http.ResponseWriter, request *http.Request) { + apiHandler.ServeLoginPost(writer, request) + }) + mux.Handle("POST /api/logout", apiHandler.ProcessAuth(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + apiHandler.ServeLogoutPost(writer, request) + }), true)) mux.HandleFunc("/api/", func(writer http.ResponseWriter, request *http.Request) { http.NotFound(writer, request) })