115 lines
2.6 KiB
Go
115 lines
2.6 KiB
Go
/*
|
|
* owobot - Your server's guardian and entertainer
|
|
* Copyright (C) 2023 owobot Contributors
|
|
*
|
|
* This program is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU Affero General Public License as published by
|
|
* the Free Software Foundation, either version 3 of the License, or
|
|
* (at your option) any later version.
|
|
*
|
|
* This program is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU Affero General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU Affero General Public License
|
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
*/
|
|
|
|
package db
|
|
|
|
import (
|
|
"context"
|
|
"embed"
|
|
"io"
|
|
"io/fs"
|
|
"path/filepath"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
//go:embed migrations
|
|
var migrations embed.FS
|
|
|
|
var db *sqlx.DB
|
|
|
|
// DB returns the global database instance
|
|
func DB() *sqlx.DB {
|
|
return db
|
|
}
|
|
|
|
// Init opens the database and applies migrations
|
|
func Init(ctx context.Context, dsn string) error {
|
|
g, err := sqlx.Open("sqlite", dsn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
db = g
|
|
return migrate(ctx, db)
|
|
}
|
|
|
|
func Close() error {
|
|
return db.Close()
|
|
}
|
|
|
|
// version returns the current version of the database.
|
|
func version(ctx context.Context, db *sqlx.DB) string {
|
|
var out string
|
|
row := db.QueryRowxContext(ctx, "SELECT current FROM version")
|
|
_ = row.Scan(&out)
|
|
if out == "" {
|
|
out = "0.sql"
|
|
}
|
|
return out
|
|
}
|
|
|
|
// migrate applies database migrations using the embedded sql files.
|
|
func migrate(ctx context.Context, db *sqlx.DB) error {
|
|
current := version(ctx, db)
|
|
return fs.WalkDir(migrations, "migrations", func(path string, d fs.DirEntry, err error) error {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// If the file is a directory, is not an sql file, or is not newer than current,
|
|
// skip it.
|
|
if d.IsDir() || filepath.Ext(path) != ".sql" || d.Name() <= current {
|
|
return nil
|
|
}
|
|
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// Open the sql file containing the migration code
|
|
fl, err := migrations.Open(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer fl.Close()
|
|
|
|
// Read the file
|
|
data, err := io.ReadAll(fl)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Execute the migration
|
|
_, err = tx.ExecContext(ctx, string(data))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Update the version number
|
|
_, err = tx.ExecContext(ctx, "DELETE FROM version; INSERT INTO version VALUES (?)", d.Name())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
})
|
|
}
|