Files
website/backend/db.go
T
2026-05-24 09:40:21 +02:00

420 lines
9.0 KiB
Go

package main
import (
	"crypto/rand"
	"crypto/sha512"
	"crypto/subtle"
	"database/sql"
	"errors"
	"log"
	"slices"
	"strconv"
	"time"

	_ "github.com/mattn/go-sqlite3"
)
// ErrNotFound is returned by GetBlogArticle and GetBlogArticleFile when no
// row matches the requested ids (and visibility filter).
var ErrNotFound = errors.New("not found")
// Database wraps the *sql.DB handle that every query in this file runs on.
// Instances are created by NewDatabase and released with Close.
type Database struct {
// db is the underlying database/sql handle opened by NewDatabase.
db *sql.DB
}
// NewDatabase opens the SQLite database at path and runs migrate on it.
//
// rootPassword is forwarded unchanged to migrate. If migrate fails, the
// freshly opened handle is closed again and the migration error is returned.
func NewDatabase(path string, rootPassword string) (*Database, error) {
	log.Println("opening database '" + path + "'")

	handle, openErr := sql.Open("sqlite3", path)
	if openErr != nil {
		return nil, openErr
	}

	if migrateErr := migrate(handle, rootPassword); migrateErr != nil {
		// Closing is best effort: the migration error is the one worth reporting.
		_ = handle.Close()
		return nil, migrateErr
	}

	return &Database{db: handle}, nil
}
// Close releases the underlying database handle. A failure to close is
// logged rather than returned, so callers have nothing to check.
func (db *Database) Close() {
	if closeErr := db.db.Close(); closeErr != nil {
		log.Println("error closing database:", closeErr.Error())
	}
}
// ValidateRootPassword reports whether password matches the root password
// stored in schema_info.
//
// The stored value is SHA-512 over the stored salt followed by the password
// bytes; the same digest is recomputed here for the candidate password. An
// error is returned only when the salt/hash row cannot be read.
func (db *Database) ValidateRootPassword(password string) (bool, error) {
	var salt, storedHash []byte
	err := db.db.QueryRow("SELECT salt, root_password FROM schema_info").Scan(&salt, &storedHash)
	if err != nil {
		return false, err
	}

	// Clone before appending so the hash input can never alias (and write
	// into spare capacity of) the buffer the driver scanned the salt into.
	salted := append(slices.Clone(salt), password...)
	candidate := sha512.Sum512(salted)

	// Compare in constant time so response timing does not reveal how many
	// leading bytes of the digest matched. A length mismatch yields 0.
	return subtle.ConstantTimeCompare(candidate[:], storedHash) == 1, nil
}
// CreateBlogArticle opens a transaction, inserts an empty draft row to
// reserve an article id, and returns that id together with a finish callback.
//
// The transaction stays open until the callback is invoked:
//   - finish(nil) rolls the transaction back, discarding the draft row;
//   - finish(article) writes title, release date, modification date and
//     content into the reserved row, links every tag (creating missing ones
//     via createOrGetTag), stores every file, and commits.
//
// The row's status is not touched by the callback, so it remains
// ArticleStatusDraft. On any failure, here or inside the callback, the
// transaction is rolled back before the error is returned so it is never
// left open.
func (db *Database) CreateBlogArticle() (int64, func(*Article) error, error) {
	tx, err := db.db.Begin()
	if err != nil {
		return -1, nil, err
	}

	res, err := tx.Exec(
		"INSERT INTO blog_article(status, title, date, content) VALUES (?, ?, ?, ?)",
		ArticleStatusDraft,
		"",
		// NOTE(review): this stores the layout string itself ("2006-01-02")
		// as a placeholder date; the callback's UPDATE overwrites it before
		// the transaction can commit.
		time.DateOnly,
		"",
	)
	if err != nil {
		_ = tx.Rollback()
		return -1, nil, err
	}

	id, err := res.LastInsertId()
	if err != nil {
		_ = tx.Rollback()
		return -1, nil, err
	}

	finish := func(article *Article) error {
		if article == nil {
			return tx.Rollback()
		}

		// write performs every statement for the article; a single rollback
		// below covers whichever step fails.
		write := func() error {
			var modificationDate *string
			if article.ModificationDate != nil {
				formatted := article.ModificationDate.Format(time.DateOnly)
				modificationDate = &formatted
			}

			if _, err := tx.Exec(
				"UPDATE blog_article SET title = ?, date = ?, modification_date = ?, content = ? WHERE id = ?",
				article.Title,
				article.ReleaseDate.Format(time.DateOnly),
				modificationDate,
				article.Content,
				id,
			); err != nil {
				return err
			}

			for _, tag := range article.Tags {
				tagID, err := createOrGetTag(tx, tag)
				if err != nil {
					return err
				}
				if _, err := tx.Exec(
					"INSERT INTO blog_article_to_tag(tag_id, article_id) VALUES (?, ?)",
					tagID,
					id,
				); err != nil {
					return err
				}
			}

			for _, file := range article.Files {
				if _, err := tx.Exec(
					"INSERT INTO blog_file(id, article_id, data) VALUES (?, ?, ?)",
					file.Id,
					id,
					file.Data,
				); err != nil {
					return err
				}
			}
			return nil
		}

		if err := write(); err != nil {
			_ = tx.Rollback()
			return err
		}
		return tx.Commit()
	}

	return id, finish, nil
}
// GetBlogArticles returns one page of article summaries, newest date first.
//
// limit and offset apply to articles, not to joined rows: a sub-select picks
// the page of article ids and the outer query then joins in their tags. When
// showAll is false only articles with status ArticleStatusPublished are
// considered. Each result carries the tag names linked to the article (an
// empty, non-nil slice when it has none). The returned slice is never nil.
//
// Stored dates that do not parse as time.DateOnly yield the zero time rather
// than failing the whole listing.
func (db *Database) GetBlogArticles(showAll bool, offset int, limit int) ([]ArticleProperties, error) {
	inner := "SELECT id FROM blog_article"
	if !showAll {
		inner += " WHERE status = " + strconv.Itoa(ArticleStatusPublished)
	}
	// The id tie-break mirrors the outer ORDER BY so that paging is stable
	// when several articles share the same date.
	inner += " ORDER BY date DESC, id LIMIT ? OFFSET ?"

	outer := "SELECT blog_article.id, blog_article.status, blog_article.title, blog_article.date, blog_article.modification_date, blog_tag.name" +
		" FROM blog_article" +
		" LEFT JOIN blog_article_to_tag ON blog_article.id = blog_article_to_tag.article_id" +
		" LEFT JOIN blog_tag ON blog_article_to_tag.tag_id = blog_tag.id" +
		" WHERE blog_article.id IN (" + inner + ")" +
		" ORDER BY blog_article.date DESC, blog_article.id"

	rows, err := db.db.Query(outer, limit, offset)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	articles := make([]ArticleProperties, 0)
	for rows.Next() {
		var (
			id                  int64
			status              ArticleStatus
			title               string
			dateStr             string
			modificationDateStr *string
			tag                 *string
		)
		if err := rows.Scan(&id, &status, &title, &dateStr, &modificationDateStr, &tag); err != nil {
			return nil, err
		}

		// Rows of one article are adjacent (ordered by date, id), so a
		// repeated id only contributes another tag to the previous entry.
		if last := len(articles) - 1; tag != nil && last >= 0 && articles[last].Id == id {
			articles[last].Tags = append(articles[last].Tags, *tag)
			continue
		}

		date, _ := time.Parse(time.DateOnly, dateStr)
		var modificationDate *time.Time
		if modificationDateStr != nil {
			parsed, _ := time.Parse(time.DateOnly, *modificationDateStr)
			modificationDate = &parsed
		}

		tags := make([]string, 0)
		if tag != nil {
			tags = append(tags, *tag)
		}

		articles = append(articles, ArticleProperties{
			id,
			title,
			status,
			tags,
			date,
			modificationDate,
		})
	}
	// rows.Next returning false can mean "done" or "failed"; only rows.Err
	// tells them apart, so a truncated result is never returned as success.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return articles, nil
}
// GetBlogArticle loads a single article with its content and tag names
// (sorted by tag name).
//
// When showAll is false the article must have status ArticleStatusPublished.
// ErrNotFound is returned when no row matches; any query or iteration error
// is returned as-is. The third field of the returned Article (its files) is
// always nil; files are fetched separately via GetBlogArticleFile.
//
// Stored dates that do not parse as time.DateOnly yield the zero time.
func (db *Database) GetBlogArticle(showAll bool, id int64) (*Article, error) {
	filter := "WHERE blog_article.id = ?"
	if !showAll {
		filter += " AND status = " + strconv.Itoa(ArticleStatusPublished)
	}
	statement := "SELECT blog_article.status, blog_article.title, blog_article.date, blog_article.modification_date, blog_article.content, blog_tag.name" +
		" FROM blog_article" +
		" LEFT JOIN blog_article_to_tag ON blog_article.id = blog_article_to_tag.article_id" +
		" LEFT JOIN blog_tag ON blog_article_to_tag.tag_id = blog_tag.id" +
		" " + filter +
		" ORDER BY blog_tag.name"

	rows, err := db.db.Query(statement, id)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	if !rows.Next() {
		// Distinguish "no such article" from a failure to fetch the first row.
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, ErrNotFound
	}

	var (
		status              ArticleStatus
		title               string
		dateStr             string
		modificationDateStr *string
		content             string
		tag                 *string
	)
	if err := rows.Scan(&status, &title, &dateStr, &modificationDateStr, &content, &tag); err != nil {
		return nil, err
	}

	date, _ := time.Parse(time.DateOnly, dateStr)
	var modificationDate *time.Time
	if modificationDateStr != nil {
		parsed, _ := time.Parse(time.DateOnly, *modificationDateStr)
		modificationDate = &parsed
	}

	article := &Article{
		ArticleProperties{
			id,
			title,
			status,
			make([]string, 0),
			date,
			modificationDate,
		},
		content,
		nil,
	}
	if tag != nil {
		article.Tags = append(article.Tags, *tag)
	}

	// Every further row repeats the article columns and differs only in the
	// tag name, so only the tag is collected.
	for rows.Next() {
		if err := rows.Scan(&status, &title, &dateStr, &modificationDateStr, &content, &tag); err != nil {
			return nil, err
		}
		if tag != nil {
			article.Tags = append(article.Tags, *tag)
		}
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return article, nil
}
// GetBlogArticleFile fetches the data of file fileId attached to article
// articleId.
//
// When showAll is false the owning article must have status
// ArticleStatusPublished. ErrNotFound is returned when no row matches; any
// other query error is passed through unchanged.
func (db *Database) GetBlogArticleFile(showAll bool, articleId int64, fileId int64) (ArticleFile, error) {
	query := "SELECT blog_file.data FROM blog_file" +
		" INNER JOIN blog_article ON blog_article.id = blog_file.article_id" +
		" WHERE blog_file.article_id = ? AND blog_file.id = ?"
	if !showAll {
		query += " AND blog_article.status = " + strconv.Itoa(ArticleStatusPublished)
	}

	var blob []byte
	switch scanErr := db.db.QueryRow(query, articleId, fileId).Scan(&blob); {
	case scanErr == nil:
		return ArticleFile{fileId, blob}, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		return ArticleFile{}, ErrNotFound
	default:
		return ArticleFile{}, scanErr
	}
}
// migrate makes sure the database carries schema version 1.
//
// Inside one transaction it reads schema_info.version. If that read fails
// for any reason the database is treated as empty and the version 1 tables
// are created via createV1Tables (rootPassword is forwarded to it). If a
// version is read, it must be 1; anything else is rejected. The transaction
// is committed on success and rolled back on every error path.
func migrate(db *sql.DB, rootPassword string) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit this Rollback only returns sql.ErrTxDone,
	// which is deliberately ignored; on every early return it releases the
	// transaction.
	defer func() { _ = tx.Rollback() }()

	log.Println("checking database schema version")
	var curVersion int
	if err := tx.QueryRow("SELECT version FROM schema_info").Scan(&curVersion); err != nil {
		// NOTE(review): every failure here (not just a missing table) is
		// taken to mean "no schema yet"; a real I/O error will resurface
		// from the CREATE statements below.
		log.Println("database is empty")
		if err := createV1Tables(tx, rootPassword); err != nil {
			return err
		}
		return tx.Commit()
	}

	log.Println("database schema version is", curVersion)
	if curVersion != 1 {
		return errors.New("unsupported database schema version")
	}
	return tx.Commit()
}
// createV1Tables creates the version 1 schema inside tx and seeds the single
// schema_info row with the version number, a fresh 32-byte random salt, and
// SHA-512(salt || rootPassword).
//
// It returns an error if rootPassword is empty, if no random salt could be
// generated, or if executing the schema statements fails. The caller owns
// tx and is responsible for committing or rolling it back.
func createV1Tables(tx *sql.Tx, rootPassword string) error {
	log.Println("creating tables for schema version 1")

	// Validate first so no work is done for a request that cannot succeed.
	if rootPassword == "" {
		return errors.New("root password is required")
	}

	salt := make([]byte, 32)
	// Checked rather than discarded: a failed read must never result in an
	// all-zero salt being stored.
	if _, err := rand.Read(salt); err != nil {
		return err
	}

	// salt has len == cap, so append allocates a new buffer and salt itself
	// is left untouched for the INSERT below.
	rootPasswordHash := sha512.Sum512(append(salt, rootPassword...))

	_, err := tx.Exec(`
CREATE TABLE schema_info(
id INTEGER PRIMARY KEY CHECK(id = 0),
version INTEGER NOT NULL,
salt BLOB NOT NULL,
root_password BLOB NOT NULL
);
INSERT INTO schema_info(
id,
version,
salt,
root_password
) VALUES (
0,
1,
?,
?
);
CREATE TABLE blog_tag(
id INTEGER PRIMARY KEY,
name TEXT NOT NULL UNIQUE
);
CREATE TABLE blog_article(
id INTEGER PRIMARY KEY,
status INTEGER NOT NULL,
title TEXT NOT NULL UNIQUE,
date TEXT NOT NULL,
modification_date TEXT,
content TEXT NOT NULL
);
CREATE TABLE blog_article_to_tag(
tag_id INTEGER NOT NULL,
article_id INTEGER NOT NULL,
PRIMARY KEY(tag_id, article_id)
);
CREATE TABLE blog_file(
id INTEGER NOT NULL,
article_id INTEGER NOT NULL,
data BLOB NOT NULL,
PRIMARY KEY(id, article_id),
FOREIGN KEY(article_id) REFERENCES blog_article(id)
);
`, salt, rootPasswordHash[:])
	return err
}
// createOrGetTag returns the id of the blog_tag row named name, inserting
// the row inside tx first when the lookup finds none. Any lookup error
// other than sql.ErrNoRows, and any insert error, is returned with id -1.
func createOrGetTag(tx *sql.Tx, name string) (int64, error) {
	var tagID int64
	lookupErr := tx.QueryRow("SELECT id FROM blog_tag WHERE name = ?", name).Scan(&tagID)
	switch {
	case lookupErr == nil:
		return tagID, nil
	case !errors.Is(lookupErr, sql.ErrNoRows):
		return -1, lookupErr
	}

	// The tag does not exist yet: create it and report the new row id.
	inserted, insertErr := tx.Exec(
		"INSERT INTO blog_tag(name) VALUES (?)",
		name,
	)
	if insertErr != nil {
		return -1, insertErr
	}
	return inserted.LastInsertId()
}