package main import ( "crypto/rand" "crypto/sha512" "database/sql" "errors" "log" "slices" "strconv" "time" _ "github.com/mattn/go-sqlite3" ) type Database struct { db *sql.DB } func NewDatabase(path string, rootPassword string) (*Database, error) { log.Println("Opening database '" + path + "'") db, err := sql.Open("sqlite3", path) if err != nil { return nil, err } err = migrate(db, rootPassword) if err != nil { _ = db.Close() return nil, err } return &Database{db}, nil } func (db *Database) Close() { err := db.db.Close() if err != nil { log.Println("Error closing database:", err.Error()) } } func (db *Database) ValidateRootPassword(password string) (bool, error) { var salt []byte var rootPasswordHash []byte err := db.db.QueryRow("SELECT salt, root_password FROM schema_info").Scan(&salt, &rootPasswordHash) if err != nil { return false, err } passwordHash := sha512.Sum512(append(salt, []byte(password)...)) return slices.Compare(passwordHash[:], rootPasswordHash) == 0, nil } func (db *Database) CreateBlogArticle() (int64, func(*Article) error, error) { tx, err := db.db.Begin() if err != nil { return -1, nil, err } res, err := tx.Exec( "INSERT INTO blog_article(status, title, date, content) VALUES (?, ?, ?, ?)", ArticleStatusDraft, "", time.DateOnly, "", ) if err != nil { return -1, nil, err } id, err := res.LastInsertId() if err != nil { return -1, nil, err } return id, func(article *Article) error { if article == nil { return tx.Rollback() } var modificationDate *string if article.ModificationDate != nil { tmp := article.ModificationDate.Format(time.DateOnly) modificationDate = &tmp } _, err := tx.Exec( "UPDATE blog_article SET title = ?, date = ?, modification_date = ?, content = ? WHERE id = ?", article.Title, article.ReleaseDate.Format(time.DateOnly), modificationDate, article.Content, id, ) if err != nil { _ = tx.Rollback() return err } for _, tag := range article.Tags { tagId, err := createOrGetTag(tx, tag) if err != nil { _ = tx.Rollback() return err } _, err = tx.Exec( "INSERT INTO blog_article_to_tag(tag_id, article_id) VALUES (?, ?)", tagId, id, ) if err != nil { return err } } for _, file := range article.Files { _, err = tx.Exec( "INSERT INTO blog_file(id, articleId, data) VALUES (?, ?, ?)", file.Id, id, file.Data, ) if err != nil { return err } } return tx.Commit() }, err } func (db *Database) GetBlogArticles(showAll bool, offset int, limit int) ([]ArticleProperties, error) { inner := "SELECT id FROM blog_article" if !showAll { inner = inner + " WHERE status = " + strconv.Itoa(ArticleStatusPublished) } inner = inner + " ORDER BY date DESC LIMIT ? OFFSET ?" outer := "SELECT blog_article.id, blog_article.status, blog_article.title, blog_article.date, blog_article.modification_date, blog_tag.name" + " FROM blog_article" + " LEFT JOIN blog_article_to_tag ON blog_article.id = blog_article_to_tag.article_id" + " LEFT JOIN blog_tag ON blog_article_to_tag.tag_id = blog_tag.id" + " WHERE blog_article.id IN (" + inner + ")" + " ORDER BY blog_article.date DESC, blog_article.id" rows, err := db.db.Query( outer, limit, offset, ) if err != nil { return nil, err } articles := make([]ArticleProperties, 0) for rows.Next() { var id int64 var status ArticleStatus var title string var dateStr string var modificationDateStr *string var tag *string err := rows.Scan(&id, &status, &title, &dateStr, &modificationDateStr, &tag) if err != nil { _ = rows.Close() return nil, err } if tag != nil && len(articles) > 0 && articles[len(articles)-1].Id == id { articles[len(articles)-1].Tags = append(articles[len(articles)-1].Tags, *tag) continue } date, _ := time.Parse(time.DateOnly, dateStr) var modificationDate *time.Time if modificationDateStr != nil { tmp, _ := time.Parse(time.DateOnly, *modificationDateStr) modificationDate = &tmp } tags := make([]string, 0) if tag != nil { tags = append(tags, *tag) } articles = append(articles, ArticleProperties{ id, title, status, tags, date, modificationDate, }) } err = rows.Close() if err != nil { return nil, err } return articles, nil } func migrate(db *sql.DB, rootPassword string) error { tx, err := db.Begin() if err != nil { return err } log.Println("Checking database schema version") var curVersion int err = tx.QueryRow("SELECT version FROM schema_info").Scan(&curVersion) if err != nil { curVersion = -1 } if curVersion == -1 { log.Println("Database is empty") err = createV1Tables(tx, rootPassword) if err != nil { _ = tx.Rollback() return err } } else { log.Println("Database schema version is", curVersion) if curVersion != 1 { return errors.New("unsupported database schema version") } } return tx.Commit() } func createV1Tables(tx *sql.Tx, rootPassword string) error { log.Println("Creating tables for schema version 1") salt := make([]byte, 32) _, _ = rand.Read(salt) if rootPassword == "" { return errors.New("root password is required") } rootPasswordHash := sha512.Sum512(append(salt, []byte(rootPassword)...)) _, err := tx.Exec(` CREATE TABLE schema_info( id INTEGER PRIMARY KEY CHECK(id = 0), version INTEGER NOT NULL, salt BLOB NOT NULL, root_password BLOB NOT NULL ); INSERT INTO schema_info( id, version, salt, root_password ) VALUES ( 0, 1, ?, ? ); CREATE TABLE blog_tag( id INTEGER PRIMARY KEY, name TEXT NOT NULL UNIQUE ); CREATE TABLE blog_article( id INTEGER PRIMARY KEY, status INTEGER NOT NULL, title TEXT NOT NULL UNIQUE, date TEXT NOT NULL, modification_date TEXT, content TEXT NOT NULL ); CREATE TABLE blog_article_to_tag( tag_id INTEGER NOT NULL, article_id INTEGER NOT NULL, PRIMARY KEY(tag_id, article_id) ); CREATE TABLE blog_file( id INTEGER NOT NULL, articleId INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(id, articleId), FOREIGN KEY(articleId) REFERENCES blog_article(id) ); `, salt, rootPasswordHash[:]) if err != nil { return err } return nil } func createOrGetTag(tx *sql.Tx, name string) (int64, error) { var id int64 err := tx.QueryRow("SELECT id FROM blog_tag WHERE name = ?", name).Scan(&id) if err == nil { return id, nil } else if !errors.Is(err, sql.ErrNoRows) { return -1, err } res, err := tx.Exec( "INSERT INTO blog_tag(name) VALUES (?)", name, ) if err != nil { return -1, err } return res.LastInsertId() }