996a538704
Signed-off-by: Tobias Erbshäußer <tobias@tesoft.dev>
486 lines
10 KiB
Go
486 lines
10 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/sha512"
|
|
"database/sql"
|
|
"errors"
|
|
"log"
|
|
"slices"
|
|
"strconv"
|
|
"time"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
var (
	// ErrNotFound is returned by Database methods when the requested article
	// or file does not exist (or, for the read methods called with showAll ==
	// false, exists but is not published).
	ErrNotFound = errors.New("not found")
)
|
|
|
|
// Database wraps the SQLite connection pool that stores blog articles, tags,
// attached files and the salted root password hash.
type Database struct {
	// db is opened by NewDatabase and released by Close.
	db *sql.DB
}
|
|
|
|
func NewDatabase(path string, rootPassword string) (*Database, error) {
|
|
log.Println("opening database '" + path + "'")
|
|
|
|
db, err := sql.Open("sqlite3", path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = migrate(db, rootPassword)
|
|
if err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
|
|
return &Database{db}, nil
|
|
}
|
|
|
|
func (db *Database) Close() {
|
|
err := db.db.Close()
|
|
if err != nil {
|
|
log.Println("error closing database:", err.Error())
|
|
}
|
|
}
|
|
|
|
func (db *Database) ValidateRootPassword(password string) (bool, error) {
|
|
var salt []byte
|
|
var rootPasswordHash []byte
|
|
err := db.db.QueryRow("SELECT salt, root_password FROM schema_info").Scan(&salt, &rootPasswordHash)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
passwordHash := sha512.Sum512(append(salt, []byte(password)...))
|
|
return slices.Compare(passwordHash[:], rootPasswordHash) == 0, nil
|
|
}
|
|
|
|
func (db *Database) CreateBlogArticle() (int64, func(*Article) error, error) {
|
|
tx, err := db.db.Begin()
|
|
if err != nil {
|
|
return -1, nil, err
|
|
}
|
|
|
|
res, err := tx.Exec(
|
|
"INSERT INTO blog_article(status, title, date, content) VALUES (?, ?, ?, ?)",
|
|
ArticleStatusDraft,
|
|
"",
|
|
time.DateOnly,
|
|
"",
|
|
)
|
|
if err != nil {
|
|
return -1, nil, err
|
|
}
|
|
|
|
id, err := res.LastInsertId()
|
|
if err != nil {
|
|
return -1, nil, err
|
|
}
|
|
|
|
return id, func(article *Article) error {
|
|
if article == nil {
|
|
return tx.Rollback()
|
|
}
|
|
|
|
var modificationDate *string
|
|
if article.ModificationDate != nil {
|
|
tmp := article.ModificationDate.Format(time.DateOnly)
|
|
modificationDate = &tmp
|
|
}
|
|
|
|
_, err := tx.Exec(
|
|
"UPDATE blog_article SET title = ?, date = ?, modification_date = ?, content = ? WHERE id = ?",
|
|
article.Title,
|
|
article.ReleaseDate.Format(time.DateOnly),
|
|
modificationDate,
|
|
article.Content,
|
|
id,
|
|
)
|
|
if err != nil {
|
|
_ = tx.Rollback()
|
|
return err
|
|
}
|
|
|
|
for _, tag := range article.Tags {
|
|
tagId, err := createOrGetTag(tx, tag)
|
|
if err != nil {
|
|
_ = tx.Rollback()
|
|
return err
|
|
}
|
|
|
|
_, err = tx.Exec(
|
|
"INSERT INTO blog_article_to_tag(tag_id, article_id) VALUES (?, ?)",
|
|
tagId,
|
|
id,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
for _, file := range article.Files {
|
|
_, err = tx.Exec(
|
|
"INSERT INTO blog_file(id, article_id, data) VALUES (?, ?, ?)",
|
|
file.Id,
|
|
id,
|
|
file.Data,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}, err
|
|
}
|
|
|
|
func (db *Database) GetBlogArticles(showAll bool, offset int, limit int) ([]ArticleProperties, error) {
|
|
inner := "SELECT id FROM blog_article"
|
|
if !showAll {
|
|
inner = inner + " WHERE status = " + strconv.Itoa(ArticleStatusPublished)
|
|
}
|
|
inner = inner + " ORDER BY date DESC LIMIT ? OFFSET ?"
|
|
outer := "SELECT blog_article.id, blog_article.status, blog_article.title, blog_article.date, blog_article.modification_date, blog_tag.name" +
|
|
" FROM blog_article" +
|
|
" LEFT JOIN blog_article_to_tag ON blog_article.id = blog_article_to_tag.article_id" +
|
|
" LEFT JOIN blog_tag ON blog_article_to_tag.tag_id = blog_tag.id" +
|
|
" WHERE blog_article.id IN (" + inner + ")" +
|
|
" ORDER BY blog_article.date DESC, blog_article.id"
|
|
|
|
rows, err := db.db.Query(
|
|
outer,
|
|
limit,
|
|
offset,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
articles := make([]ArticleProperties, 0)
|
|
for rows.Next() {
|
|
var id int64
|
|
var status ArticleStatus
|
|
var title string
|
|
var dateStr string
|
|
var modificationDateStr *string
|
|
var tag *string
|
|
err := rows.Scan(&id, &status, &title, &dateStr, &modificationDateStr, &tag)
|
|
if err != nil {
|
|
_ = rows.Close()
|
|
return nil, err
|
|
}
|
|
|
|
if tag != nil && len(articles) > 0 && articles[len(articles)-1].Id == id {
|
|
articles[len(articles)-1].Tags = append(articles[len(articles)-1].Tags, *tag)
|
|
continue
|
|
}
|
|
|
|
date, _ := time.Parse(time.DateOnly, dateStr)
|
|
var modificationDate *time.Time
|
|
if modificationDateStr != nil {
|
|
tmp, _ := time.Parse(time.DateOnly, *modificationDateStr)
|
|
modificationDate = &tmp
|
|
}
|
|
|
|
tags := make([]string, 0)
|
|
if tag != nil {
|
|
tags = append(tags, *tag)
|
|
}
|
|
|
|
articles = append(articles, ArticleProperties{
|
|
id,
|
|
title,
|
|
status,
|
|
tags,
|
|
date,
|
|
modificationDate,
|
|
})
|
|
}
|
|
|
|
err = rows.Close()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return articles, nil
|
|
}
|
|
|
|
func (db *Database) GetBlogArticle(showAll bool, id int64) (*Article, error) {
|
|
filter := "WHERE blog_article.id = ?"
|
|
if !showAll {
|
|
filter = filter + " AND status = " + strconv.Itoa(ArticleStatusPublished)
|
|
}
|
|
|
|
statement := "SELECT blog_article.status, blog_article.title, blog_article.date, blog_article.modification_date, blog_article.content, blog_tag.name" +
|
|
" FROM blog_article" +
|
|
" LEFT JOIN blog_article_to_tag ON blog_article.id = blog_article_to_tag.article_id" +
|
|
" LEFT JOIN blog_tag ON blog_article_to_tag.tag_id = blog_tag.id" +
|
|
" " + filter +
|
|
" ORDER BY blog_tag.name"
|
|
|
|
rows, err := db.db.Query(statement, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
if !rows.Next() {
|
|
return nil, ErrNotFound
|
|
}
|
|
|
|
var status ArticleStatus
|
|
var title string
|
|
var dateStr string
|
|
var modificationDateStr *string
|
|
var content string
|
|
var tag *string
|
|
|
|
err = rows.Scan(&status, &title, &dateStr, &modificationDateStr, &content, &tag)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
date, _ := time.Parse(time.DateOnly, dateStr)
|
|
var modificationDate *time.Time
|
|
if modificationDateStr != nil {
|
|
tmp, _ := time.Parse(time.DateOnly, *modificationDateStr)
|
|
modificationDate = &tmp
|
|
}
|
|
|
|
article := &Article{
|
|
ArticleProperties{
|
|
id,
|
|
title,
|
|
status,
|
|
make([]string, 0),
|
|
date,
|
|
modificationDate,
|
|
},
|
|
content,
|
|
nil,
|
|
}
|
|
|
|
if tag != nil {
|
|
article.Tags = append(article.Tags, *tag)
|
|
}
|
|
|
|
for rows.Next() {
|
|
err = rows.Scan(&status, &title, &dateStr, &modificationDateStr, &content, &tag)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if tag != nil {
|
|
article.Tags = append(article.Tags, *tag)
|
|
}
|
|
}
|
|
|
|
return article, nil
|
|
}
|
|
|
|
func (db *Database) GetBlogArticleFile(showAll bool, articleId int64, fileId int64) (ArticleFile, error) {
|
|
filter := "WHERE blog_file.article_id = ? AND blog_file.id = ?"
|
|
if !showAll {
|
|
filter = filter + " AND blog_article.status = " + strconv.Itoa(ArticleStatusPublished)
|
|
}
|
|
|
|
statement := "SELECT blog_file.data FROM blog_file" +
|
|
" INNER JOIN blog_article ON blog_article.id = blog_file.article_id" +
|
|
" " + filter
|
|
|
|
var data []byte
|
|
err := db.db.QueryRow(statement, articleId, fileId).Scan(&data)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return ArticleFile{}, ErrNotFound
|
|
}
|
|
|
|
return ArticleFile{}, err
|
|
}
|
|
|
|
return ArticleFile{
|
|
fileId,
|
|
data,
|
|
}, nil
|
|
}
|
|
|
|
func (db *Database) SetBlogArticleStatus(id int64, status ArticleStatus) error {
|
|
res, err := db.db.Exec("UPDATE blog_article SET status = ? WHERE id = ?", status, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if affected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db *Database) GetBlogTags(showAll bool, offset int, limit int) ([]string, int64, error) {
|
|
filter := ""
|
|
filterArgs := make([]interface{}, 0)
|
|
|
|
if !showAll {
|
|
filter = " WHERE blog_article.status = ?"
|
|
filterArgs = append(filterArgs, ArticleStatusPublished)
|
|
}
|
|
|
|
args := make([]interface{}, 0, len(filterArgs)*2+2)
|
|
args = append(args, filterArgs...)
|
|
args = append(args, filterArgs...)
|
|
args = append(args, limit)
|
|
args = append(args, offset)
|
|
|
|
joins := " INNER JOIN blog_article_to_tag ON blog_article_to_tag.tag_id = blog_tag.id" +
|
|
" INNER JOIN blog_article ON blog_article.id = blog_article_to_tag.article_id"
|
|
|
|
rows, err := db.db.Query(
|
|
"SELECT blog_tag.name, (SELECT COUNT(*) FROM blog_tag"+joins+filter+") FROM blog_tag"+joins+filter+" LIMIT ? OFFSET ?",
|
|
args...,
|
|
)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
tags := make([]string, 0)
|
|
total := int64(0)
|
|
|
|
for rows.Next() {
|
|
var name string
|
|
err := rows.Scan(&name, &total)
|
|
if err != nil {
|
|
_ = rows.Close()
|
|
return nil, 0, err
|
|
}
|
|
|
|
tags = append(tags, name)
|
|
}
|
|
|
|
err = rows.Close()
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
return tags, total, nil
|
|
}
|
|
|
|
func migrate(db *sql.DB, rootPassword string) error {
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Println("checking database schema version")
|
|
var curVersion int
|
|
err = tx.QueryRow("SELECT version FROM schema_info").Scan(&curVersion)
|
|
if err != nil {
|
|
curVersion = -1
|
|
}
|
|
|
|
if curVersion == -1 {
|
|
log.Println("database is empty")
|
|
|
|
err = createV1Tables(tx, rootPassword)
|
|
if err != nil {
|
|
_ = tx.Rollback()
|
|
return err
|
|
}
|
|
} else {
|
|
log.Println("database schema version is", curVersion)
|
|
|
|
if curVersion != 1 {
|
|
return errors.New("unsupported database schema version")
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// createV1Tables creates the version 1 schema inside tx. It also stores the
// schema version and the root password. The password is stored as
// SHA-512(salt || password) with a fresh 32-byte random salt; this is the
// format ValidateRootPassword checks against. It returns an error if
// rootPassword is empty, if no random salt can be generated, or if any
// statement fails.
func createV1Tables(tx *sql.Tx, rootPassword string) error {
	log.Println("creating tables for schema version 1")

	if rootPassword == "" {
		return errors.New("root password is required")
	}

	// A failed read must not silently leave an all-zero salt behind.
	salt := make([]byte, 32)
	if _, err := rand.Read(salt); err != nil {
		return err
	}

	// salt has len == cap, so append allocates a new slice and leaves the
	// stored salt untouched.
	// NOTE(review): a single salted SHA-512 is fast to brute-force; a password
	// KDF would be stronger, but ValidateRootPassword must change in lockstep.
	rootPasswordHash := sha512.Sum512(append(salt, []byte(rootPassword)...))

	// The driver runs the statements in order and hands each one as many
	// arguments as it has placeholders; only the INSERT takes any (salt, hash).
	_, err := tx.Exec(`
	CREATE TABLE schema_info(
		id INTEGER PRIMARY KEY CHECK(id = 0),
		version INTEGER NOT NULL,
		salt BLOB NOT NULL,
		root_password BLOB NOT NULL
	);
	INSERT INTO schema_info(
		id,
		version,
		salt,
		root_password
	) VALUES (
		0,
		1,
		?,
		?
	);

	CREATE TABLE blog_tag(
		id INTEGER PRIMARY KEY,
		name TEXT NOT NULL UNIQUE
	);
	CREATE TABLE blog_article(
		id INTEGER PRIMARY KEY,
		status INTEGER NOT NULL,
		title TEXT NOT NULL UNIQUE,
		date TEXT NOT NULL,
		modification_date TEXT,
		content TEXT NOT NULL
	);
	CREATE TABLE blog_article_to_tag(
		tag_id INTEGER NOT NULL,
		article_id INTEGER NOT NULL,
		PRIMARY KEY(tag_id, article_id)
	);
	CREATE TABLE blog_file(
		id INTEGER NOT NULL,
		article_id INTEGER NOT NULL,
		data BLOB NOT NULL,
		PRIMARY KEY(id, article_id),
		FOREIGN KEY(article_id) REFERENCES blog_article(id)
	);
	`, salt, rootPasswordHash[:])
	return err
}
|
|
|
|
// createOrGetTag returns the id of the tag called name. If no such tag exists
// yet, it inserts a new blog_tag row inside tx first. It returns -1 together
// with the error if a query fails.
func createOrGetTag(tx *sql.Tx, name string) (int64, error) {
	var existingID int64
	switch lookupErr := tx.QueryRow("SELECT id FROM blog_tag WHERE name = ?", name).Scan(&existingID); {
	case lookupErr == nil:
		return existingID, nil
	case !errors.Is(lookupErr, sql.ErrNoRows):
		return -1, lookupErr
	}

	// The tag is unknown: create it and report the id SQLite assigned.
	inserted, err := tx.Exec("INSERT INTO blog_tag(name) VALUES (?)", name)
	if err != nil {
		return -1, err
	}
	return inserted.LastInsertId()
}
|