add login and logout endpoints
Signed-off-by: Tobias Erbshäußer <tobias@tesoft.dev>
This commit is contained in:
@@ -2,6 +2,11 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -13,6 +18,82 @@ type ApiHandler struct {
|
||||
const authTokenCookieName = "auth-token"
|
||||
const isAuthorizedContextKey = "is-authorized"
|
||||
|
||||
func (h *ApiHandler) ServeLoginPost(writer http.ResponseWriter, request *http.Request) {
|
||||
bodyReader := request.Body
|
||||
body, err := io.ReadAll(bodyReader)
|
||||
_ = bodyReader.Close()
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
type LoginBody struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
loginBody := LoginBody{}
|
||||
err = json.Unmarshal(body, &loginBody)
|
||||
if err != nil {
|
||||
http.Error(writer, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
success, err := h.db.ValidateRootPassword(loginBody.Password)
|
||||
if err != nil {
|
||||
log.Println("Error logging in:", err)
|
||||
http.Error(writer, "failed to read database", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !success {
|
||||
http.Error(writer, "invalid password", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
rawAuthToken := make([]byte, 128)
|
||||
_, _ = rand.Read(rawAuthToken)
|
||||
authToken := hex.EncodeToString(rawAuthToken)
|
||||
h.authToken = &authToken
|
||||
|
||||
cookie := http.Cookie{}
|
||||
cookie.Name = authTokenCookieName
|
||||
cookie.Value = authToken
|
||||
cookie.Secure = true
|
||||
cookie.HttpOnly = true
|
||||
http.SetCookie(writer, &cookie)
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
err = json.NewEncoder(writer).Encode(map[string]interface{}{})
|
||||
if err != nil {
|
||||
http.Error(writer, "failed to serialize results", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (h *ApiHandler) ServeLogoutPost(writer http.ResponseWriter, request *http.Request) {
|
||||
cookie, _ := request.Cookie(authTokenCookieName)
|
||||
if cookie != nil {
|
||||
cookie := http.Cookie{}
|
||||
cookie.Name = authTokenCookieName
|
||||
cookie.Value = ""
|
||||
cookie.Secure = true
|
||||
cookie.HttpOnly = true
|
||||
http.SetCookie(writer, &cookie)
|
||||
}
|
||||
|
||||
h.authToken = nil
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
err := json.NewEncoder(writer).Encode(map[string]interface{}{})
|
||||
if err != nil {
|
||||
http.Error(writer, "failed to serialize results", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ApiHandler) ProcessAuth(next http.Handler, required bool) http.Handler {
|
||||
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
isAuthorized := false
|
||||
|
||||
+40
-11
@@ -1,9 +1,12 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha512"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
"slices"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -14,7 +17,7 @@ type Database struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewDatabase(path string) (*Database, error) {
|
||||
func NewDatabase(path string, rootPassword string) (*Database, error) {
|
||||
log.Println("Opening database '" + path + "'")
|
||||
|
||||
db, err := sql.Open("sqlite3", path)
|
||||
@@ -22,7 +25,7 @@ func NewDatabase(path string) (*Database, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = migrate(db)
|
||||
err = migrate(db, rootPassword)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
@@ -38,6 +41,18 @@ func (db *Database) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (db *Database) ValidateRootPassword(password string) (bool, error) {
|
||||
var salt []byte
|
||||
var rootPasswordHash []byte
|
||||
err := db.db.QueryRow("SELECT salt, root_password FROM schema_info").Scan(&salt, &rootPasswordHash)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
passwordHash := sha512.Sum512(append(salt, []byte(password)...))
|
||||
return slices.Compare(passwordHash[:], rootPasswordHash) == 0, nil
|
||||
}
|
||||
|
||||
func (db *Database) CreateBlogArticle() (int64, func(*Article) error, error) {
|
||||
tx, err := db.db.Begin()
|
||||
if err != nil {
|
||||
@@ -188,10 +203,10 @@ func (db *Database) GetBlogArticles(showAll bool, offset int, limit int) ([]Arti
|
||||
return articles, nil
|
||||
}
|
||||
|
||||
func migrate(db *sql.DB) error {
|
||||
func migrate(db *sql.DB, rootPassword string) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Println("Checking database schema version")
|
||||
@@ -204,7 +219,7 @@ func migrate(db *sql.DB) error {
|
||||
if curVersion == -1 {
|
||||
log.Println("Database is empty")
|
||||
|
||||
err = createV1Tables(tx)
|
||||
err = createV1Tables(tx, rootPassword)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
@@ -213,27 +228,41 @@ func migrate(db *sql.DB) error {
|
||||
log.Println("Database schema version is", curVersion)
|
||||
|
||||
if curVersion != 1 {
|
||||
log.Fatalln("Unsupported database schema version")
|
||||
return errors.New("unsupported database schema version")
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func createV1Tables(tx *sql.Tx) error {
|
||||
func createV1Tables(tx *sql.Tx, rootPassword string) error {
|
||||
log.Println("Creating tables for schema version 1")
|
||||
|
||||
salt := make([]byte, 32)
|
||||
_, _ = rand.Read(salt)
|
||||
|
||||
if rootPassword == "" {
|
||||
return errors.New("root password is required")
|
||||
}
|
||||
rootPasswordHash := sha512.Sum512(append(salt, []byte(rootPassword)...))
|
||||
|
||||
_, err := tx.Exec(`
|
||||
CREATE TABLE schema_info(
|
||||
id INTEGER PRIMARY KEY CHECK(id = 0),
|
||||
version INTEGER NOT NULL
|
||||
version INTEGER NOT NULL,
|
||||
salt BLOB NOT NULL,
|
||||
root_password BLOB NOT NULL
|
||||
);
|
||||
INSERT INTO schema_info(
|
||||
id,
|
||||
version
|
||||
version,
|
||||
salt,
|
||||
root_password
|
||||
) VALUES (
|
||||
0,
|
||||
1
|
||||
1,
|
||||
?,
|
||||
?
|
||||
);
|
||||
|
||||
CREATE TABLE blog_tag(
|
||||
@@ -260,7 +289,7 @@ func createV1Tables(tx *sql.Tx) error {
|
||||
PRIMARY KEY(id, articleId),
|
||||
FOREIGN KEY(articleId) REFERENCES blog_article(id)
|
||||
);
|
||||
`)
|
||||
`, salt, rootPasswordHash[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
+7
-1
@@ -22,7 +22,7 @@ func main() {
|
||||
dbPath = "./database.sqlite"
|
||||
}
|
||||
|
||||
db, err := NewDatabase(dbPath)
|
||||
db, err := NewDatabase(dbPath, os.Getenv("ROOT_PASSWORD"))
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
@@ -44,6 +44,12 @@ func main() {
|
||||
mux.Handle("GET /api/blog/{articleId}/file/{fileId}", apiHandler.ProcessAuth(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
apiHandler.ServeBlogFileGetSingle(writer, request)
|
||||
}), false))
|
||||
mux.HandleFunc("POST /api/login", func(writer http.ResponseWriter, request *http.Request) {
|
||||
apiHandler.ServeLoginPost(writer, request)
|
||||
})
|
||||
mux.Handle("POST /api/logout", apiHandler.ProcessAuth(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
apiHandler.ServeLogoutPost(writer, request)
|
||||
}), true))
|
||||
mux.HandleFunc("/api/", func(writer http.ResponseWriter, request *http.Request) {
|
||||
http.NotFound(writer, request)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user