1
0
Fork 0
mirror of https://github.com/thomiceli/opengist.git synced 2025-01-24 07:00:32 +00:00
opengist/internal/db/db.go

263 lines
5.6 KiB
Go
Raw Permalink Normal View History

2023-09-03 00:30:57 +02:00
package db
2023-03-14 16:22:52 +01:00
import (
"errors"
"fmt"
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm/logger"
"net/url"
"path/filepath"
"slices"
"strings"
"time"
2023-06-09 15:25:41 +02:00
"github.com/rs/zerolog/log"
"github.com/thomiceli/opengist/internal/config"
2023-03-14 16:22:52 +01:00
"gorm.io/gorm"
)
// db is the package-wide gorm handle. It is assigned by setupSQLite,
// setupPostgres or setupMySQL during Setup and is used by the helpers in this
// file (Close, Ping, CountAll, TruncateDatabase, ...). It is nil until Setup
// has successfully connected.
var db *gorm.DB
// databaseType identifies which SQL backend a connection URI selects.
type databaseType int

const (
	// SQLite is an embedded SQLite database (file path, file: URI or ":memory:").
	SQLite databaseType = iota
	// PostgreSQL is a PostgreSQL server reached over the network.
	PostgreSQL
	// MySQL is a MySQL or MariaDB server reached over the network.
	MySQL
)

// String returns the human-readable backend name, e.g. "PostgreSQL".
//
// Values outside the declared constants are rendered as "databaseType(N)"
// rather than panicking with an index-out-of-range error. This keeps log and
// error messages that format an unexpected value from blowing up.
func (d databaseType) String() string {
	names := [...]string{"SQLite", "PostgreSQL", "MySQL"}
	if d < 0 || int(d) >= len(names) {
		return fmt.Sprintf("databaseType(%d)", int(d))
	}
	return names[d]
}
// databaseInfo holds the connection parameters that parseDBURI extracts from a
// database URI.
type databaseInfo struct {
	// Type is the backend selected by the URI scheme.
	Type databaseType

	// Host and Port form the server address; they are consumed by the
	// PostgreSQL and MySQL setup functions only.
	Host, Port string

	// User and Password come from the userinfo section of the URI.
	User, Password string

	// Database is the database name for server backends, or the SQLite file
	// path / file: URI / ":memory:" for SQLite.
	Database string
}

// DatabaseInfo describes the active database connection. Setup assigns it
// once a connection has been established.
var DatabaseInfo *databaseInfo
func parseDBURI(uri string) (*databaseInfo, error) {
info := &databaseInfo{}
2025-01-20 01:57:39 +01:00
if uri == ":memory:" {
info.Type = SQLite
info.Database = uri
return info, nil
}
u, err := url.Parse(uri)
if err != nil {
return nil, fmt.Errorf("invalid URI: %v", err)
2023-06-09 15:25:41 +02:00
}
if u.Scheme == "" {
info.Type = SQLite
info.Database = filepath.Join(config.GetHomeDir(), uri)
return info, nil
}
switch u.Scheme {
case "postgres", "postgresql":
info.Type = PostgreSQL
case "mysql", "mariadb":
info.Type = MySQL
case "file":
info.Type = SQLite
default:
return nil, fmt.Errorf("unknown database: %v", err)
}
if u.Host != "" {
host, port, _ := strings.Cut(u.Host, ":")
info.Host = host
info.Port = port
}
if u.User != nil {
info.User = u.User.Username()
info.Password, _ = u.User.Password()
}
switch info.Type {
case PostgreSQL, MySQL:
info.Database = strings.TrimPrefix(u.Path, "/")
case SQLite:
info.Database = u.String()
default:
return nil, fmt.Errorf("unknown database: %v", err)
2023-09-17 00:59:47 +02:00
}
return info, nil
}
2025-01-20 01:57:39 +01:00
func Setup(dbUri string) error {
dbInfo, err := parseDBURI(dbUri)
if err != nil {
2023-03-14 16:22:52 +01:00
return err
}
log.Info().Msgf("Setting up a %s database connection", dbInfo.Type)
2025-01-20 01:57:39 +01:00
var setupFunc func(databaseInfo) error
switch dbInfo.Type {
case SQLite:
setupFunc = setupSQLite
case PostgreSQL:
setupFunc = setupPostgres
case MySQL:
setupFunc = setupMySQL
default:
return fmt.Errorf("unknown database type: %v", dbInfo.Type)
}
maxAttempts := 60
retryInterval := 1 * time.Second
for attempt := 1; attempt <= maxAttempts; attempt++ {
2025-01-20 01:57:39 +01:00
err = setupFunc(*dbInfo)
if err == nil {
log.Info().Msg("Database connection established")
break
}
if attempt < maxAttempts {
log.Warn().Err(err).Msgf("Failed to connect to database (attempt %d), retrying in %v...", attempt, retryInterval)
time.Sleep(retryInterval)
} else {
return err
}
}
DatabaseInfo = dbInfo
2023-06-21 18:19:17 +02:00
if err = db.SetupJoinTable(&Gist{}, "Likes", &Like{}); err != nil {
return err
}
if err = db.SetupJoinTable(&User{}, "Liked", &Like{}); err != nil {
return err
}
2024-10-24 23:23:00 +02:00
if err = db.AutoMigrate(&User{}, &Gist{}, &SSHKey{}, &AdminSetting{}, &Invitation{}, &WebAuthnCredential{}, &TOTP{}); err != nil {
2023-03-14 16:22:52 +01:00
return err
}
2025-01-20 01:57:39 +01:00
if err = applyMigrations(dbInfo); err != nil {
2023-09-17 02:55:17 +02:00
return err
}
// Default admin setting values
return initAdminSettings(map[string]string{
SettingDisableSignup: "0",
SettingRequireLogin: "0",
SettingAllowGistsWithoutLogin: "0",
SettingDisableLoginForm: "0",
SettingDisableGravatar: "0",
})
2023-03-14 16:22:52 +01:00
}
2023-09-17 00:59:47 +02:00
// Close closes the *sql.DB connection pool behind the package-level gorm
// handle. It returns the error from obtaining the pool, or else the error
// from closing it.
func Close() error {
	pool, err := db.DB()
	if err == nil {
		err = pool.Close()
	}
	return err
}
2023-03-14 16:22:52 +01:00
// CountAll returns the number of rows in the table backing the given model
// (e.g. &User{}), together with any query error.
func CountAll(table interface{}) (int64, error) {
	var total int64
	result := db.Model(table).Count(&total)
	return total, result.Error
}
// IsUniqueConstraintViolation reports whether err, or any error it wraps, is
// gorm.ErrDuplicatedKey. The setup functions in this file open the
// connection with TranslateError enabled, so driver-specific duplicate-key
// errors are expected to surface as that sentinel.
func IsUniqueConstraintViolation(err error) bool {
	if errors.Is(err, gorm.ErrDuplicatedKey) {
		return true
	}
	return false
}
2023-12-16 01:27:00 +01:00
// Ping checks that the database behind the package-level gorm handle is still
// reachable. It returns the error from obtaining the pool or from the ping.
func Ping() error {
	sqlDB, err := db.DB()
	if err == nil {
		return sqlDB.Ping()
	}
	return err
}
2025-01-20 01:57:39 +01:00
func setupSQLite(dbInfo databaseInfo) error {
var err error
2025-01-20 01:57:39 +01:00
var dsn string
journalMode := strings.ToUpper(config.C.SqliteJournalMode)
if !slices.Contains([]string{"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"}, journalMode) {
log.Warn().Msg("Invalid SQLite journal mode: " + journalMode)
}
2025-01-20 01:57:39 +01:00
if dbInfo.Database == ":memory:" {
dsn = ":memory:?_fk=true&cache=shared"
} else {
u, err := url.Parse(dbInfo.Database)
if err != nil {
return err
}
2025-01-20 01:57:39 +01:00
u.Scheme = "file"
q := u.Query()
q.Set("_fk", "true")
q.Set("_journal_mode", journalMode)
u.RawQuery = q.Encode()
dsn = u.String()
}
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
TranslateError: true,
})
return err
}
2025-01-20 01:57:39 +01:00
// setupPostgres opens a PostgreSQL connection described by dbInfo and stores
// the handle in the package-level db. TLS is disabled (sslmode=disable).
//
// The DSN uses libpq keyword/value syntax. Each value is single-quoted, with
// backslashes and single quotes backslash-escaped, so passwords or names
// containing spaces or quotes pass through intact.
//
// Keywords whose value is empty are left out entirely. An unquoted
// "port= user=x" can be parsed as port "user=x", whereas omitting the keyword
// lets the driver apply its own default (e.g. port 5432).
func setupPostgres(dbInfo databaseInfo) error {
	escape := strings.NewReplacer(`\`, `\\`, `'`, `\'`)

	params := []struct{ key, value string }{
		{"host", dbInfo.Host},
		{"port", dbInfo.Port},
		{"user", dbInfo.User},
		{"password", dbInfo.Password},
		{"dbname", dbInfo.Database},
	}
	parts := make([]string, 0, len(params)+1)
	for _, p := range params {
		if p.value == "" {
			continue
		}
		parts = append(parts, p.key+"='"+escape.Replace(p.value)+"'")
	}
	parts = append(parts, "sslmode=disable")
	dsn := strings.Join(parts, " ")

	// Assign (not declare) so the package-level db is updated.
	var err error
	db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
		Logger:         logger.Default.LogMode(logger.Silent),
		TranslateError: true,
	})
	return err
}
2025-01-20 01:57:39 +01:00
// setupMySQL opens a MySQL/MariaDB connection described by dbInfo and stores
// the handle in the package-level db.
//
// The TCP address is built from Host and Port with two adjustments:
//   - Port 3306 (the MySQL default) is used when the URI gave none. Otherwise
//     the address would end in "host:" with an empty port.
//   - A host containing ':' (an IPv6 literal such as "::1") is wrapped in
//     brackets, as net.JoinHostPort would do, so the address stays parseable.
func setupMySQL(dbInfo databaseInfo) error {
	port := dbInfo.Port
	if port == "" {
		port = "3306"
	}
	host := dbInfo.Host
	if strings.Contains(host, ":") {
		host = "[" + host + "]"
	}
	addr := host + ":" + port

	dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", dbInfo.User, dbInfo.Password, addr, dbInfo.Database)

	// Assign (not declare) so the package-level db is updated.
	var err error
	db, err = gorm.Open(mysql.New(mysql.Config{
		DSN:                    dsn,
		DontSupportRenameIndex: true,
	}), &gorm.Config{
		Logger:         logger.Default.LogMode(logger.Silent),
		TranslateError: true,
	})
	return err
}
// DeprecationDBFilename handles the deprecated db-filename option.
//
// It logs a deprecation warning when db-filename is set. When no db-uri is
// configured, it copies the db-filename value into DBUri so the rest of the
// package only has to read DBUri.
func DeprecationDBFilename() {
	legacy := config.C.DBFilename
	if legacy != "" {
		log.Warn().Msg("The 'db-filename'/'OG_DB_FILENAME' configuration option is deprecated and will be removed in a future version. Please use 'db-uri'/'OG_DB_URI' instead.")
	}

	if config.C.DBUri == "" {
		config.C.DBUri = legacy
	}
}
func TruncateDatabase() error {
2024-10-24 23:23:00 +02:00
return db.Migrator().DropTable("likes", &User{}, "gists", &SSHKey{}, &AdminSetting{}, &Invitation{}, &WebAuthnCredential{}, &TOTP{})
}