add login and logout endpoints
Signed-off-by: Tobias Erbshäußer <tobias@tesoft.dev>
This commit is contained in:
+40
-11
@@ -1,9 +1,12 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha512"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
"slices"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -14,7 +17,7 @@ type Database struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewDatabase(path string) (*Database, error) {
|
||||
func NewDatabase(path string, rootPassword string) (*Database, error) {
|
||||
log.Println("Opening database '" + path + "'")
|
||||
|
||||
db, err := sql.Open("sqlite3", path)
|
||||
@@ -22,7 +25,7 @@ func NewDatabase(path string) (*Database, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = migrate(db)
|
||||
err = migrate(db, rootPassword)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
@@ -38,6 +41,18 @@ func (db *Database) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (db *Database) ValidateRootPassword(password string) (bool, error) {
|
||||
var salt []byte
|
||||
var rootPasswordHash []byte
|
||||
err := db.db.QueryRow("SELECT salt, root_password FROM schema_info").Scan(&salt, &rootPasswordHash)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
passwordHash := sha512.Sum512(append(salt, []byte(password)...))
|
||||
return slices.Compare(passwordHash[:], rootPasswordHash) == 0, nil
|
||||
}
|
||||
|
||||
func (db *Database) CreateBlogArticle() (int64, func(*Article) error, error) {
|
||||
tx, err := db.db.Begin()
|
||||
if err != nil {
|
||||
@@ -188,10 +203,10 @@ func (db *Database) GetBlogArticles(showAll bool, offset int, limit int) ([]Arti
|
||||
return articles, nil
|
||||
}
|
||||
|
||||
func migrate(db *sql.DB) error {
|
||||
func migrate(db *sql.DB, rootPassword string) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Println("Checking database schema version")
|
||||
@@ -204,7 +219,7 @@ func migrate(db *sql.DB) error {
|
||||
if curVersion == -1 {
|
||||
log.Println("Database is empty")
|
||||
|
||||
err = createV1Tables(tx)
|
||||
err = createV1Tables(tx, rootPassword)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
@@ -213,27 +228,41 @@ func migrate(db *sql.DB) error {
|
||||
log.Println("Database schema version is", curVersion)
|
||||
|
||||
if curVersion != 1 {
|
||||
log.Fatalln("Unsupported database schema version")
|
||||
return errors.New("unsupported database schema version")
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func createV1Tables(tx *sql.Tx) error {
|
||||
func createV1Tables(tx *sql.Tx, rootPassword string) error {
|
||||
log.Println("Creating tables for schema version 1")
|
||||
|
||||
salt := make([]byte, 32)
|
||||
_, _ = rand.Read(salt)
|
||||
|
||||
if rootPassword == "" {
|
||||
return errors.New("root password is required")
|
||||
}
|
||||
rootPasswordHash := sha512.Sum512(append(salt, []byte(rootPassword)...))
|
||||
|
||||
_, err := tx.Exec(`
|
||||
CREATE TABLE schema_info(
|
||||
id INTEGER PRIMARY KEY CHECK(id = 0),
|
||||
version INTEGER NOT NULL
|
||||
version INTEGER NOT NULL,
|
||||
salt BLOB NOT NULL,
|
||||
root_password BLOB NOT NULL
|
||||
);
|
||||
INSERT INTO schema_info(
|
||||
id,
|
||||
version
|
||||
version,
|
||||
salt,
|
||||
root_password
|
||||
) VALUES (
|
||||
0,
|
||||
1
|
||||
1,
|
||||
?,
|
||||
?
|
||||
);
|
||||
|
||||
CREATE TABLE blog_tag(
|
||||
@@ -260,7 +289,7 @@ func createV1Tables(tx *sql.Tx) error {
|
||||
PRIMARY KEY(id, articleId),
|
||||
FOREIGN KEY(articleId) REFERENCES blog_article(id)
|
||||
);
|
||||
`)
|
||||
`, salt, rootPasswordHash[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user