2023-09-03 00:30:57 +02:00
package db
2023-03-14 16:22:52 +01:00
import (
2023-04-17 14:25:39 +02:00
"errors"
2024-09-20 16:01:09 +02:00
"fmt"
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm/logger"
"net/url"
"path/filepath"
2024-01-04 05:11:46 +01:00
"slices"
2023-06-16 11:08:33 -05:00
"strings"
2024-09-20 16:01:09 +02:00
"time"
2023-06-16 11:08:33 -05:00
2023-06-09 15:25:41 +02:00
"github.com/rs/zerolog/log"
"github.com/thomiceli/opengist/internal/config"
2023-03-14 16:22:52 +01:00
"gorm.io/gorm"
)
// db is the package-level GORM handle. It is assigned by setupSQLite,
// setupPostgres or setupMySQL (selected in Setup). Setup, Close, Ping,
// CountAll and TruncateDatabase then read it.
var db * gorm . DB
2024-09-20 16:01:09 +02:00
// Supported database backends. The values come from iota and index the
// name table used by databaseType.String.
const (
	SQLite databaseType = iota
	PostgreSQL
	MySQL
)

// databaseType identifies which SQL backend a connection targets.
type databaseType int

// String returns the display name of the backend: "SQLite", "PostgreSQL" or
// "MySQL". A value outside the declared constants yields "databaseType(N)"
// instead of panicking with an index-out-of-range error. This matters because
// fmt calls String whenever the value is formatted with %v or %s, for example
// in log and error messages.
func (d databaseType) String() string {
	names := [...]string{"SQLite", "PostgreSQL", "MySQL"}
	if d < 0 || int(d) >= len(names) {
		// Convert to int so %d does not route back through String.
		return fmt.Sprintf("databaseType(%d)", int(d))
	}
	return names[d]
}
// databaseInfo holds the connection parameters that parseDBURI extracts from a
// database URI.
type databaseInfo struct {
	Type databaseType // backend selected from the URI scheme

	// Network location and credentials. These are empty when the URI omits
	// them, and always empty for ":memory:" or plain-path SQLite URIs.
	Host, Port     string
	User, Password string

	// Database is the database name for PostgreSQL/MySQL. For SQLite it is
	// ":memory:", a filesystem path, or a file: URI.
	Database string
}

// DatabaseInfo describes the active database. Setup assigns it once a
// connection has been established.
var DatabaseInfo *databaseInfo
func parseDBURI ( uri string ) ( * databaseInfo , error ) {
info := & databaseInfo { }
2025-01-20 01:57:39 +01:00
if uri == ":memory:" {
info . Type = SQLite
info . Database = uri
return info , nil
}
2024-09-20 16:01:09 +02:00
u , err := url . Parse ( uri )
if err != nil {
return nil , fmt . Errorf ( "invalid URI: %v" , err )
2023-06-09 15:25:41 +02:00
}
2024-11-25 22:07:13 +01:00
if u . Scheme == "" {
info . Type = SQLite
info . Database = filepath . Join ( config . GetHomeDir ( ) , uri )
return info , nil
}
2024-09-20 16:01:09 +02:00
switch u . Scheme {
case "postgres" , "postgresql" :
info . Type = PostgreSQL
case "mysql" , "mariadb" :
info . Type = MySQL
2024-11-25 22:07:13 +01:00
case "file" :
info . Type = SQLite
2024-09-20 16:01:09 +02:00
default :
return nil , fmt . Errorf ( "unknown database: %v" , err )
}
if u . Host != "" {
host , port , _ := strings . Cut ( u . Host , ":" )
info . Host = host
info . Port = port
}
if u . User != nil {
info . User = u . User . Username ( )
info . Password , _ = u . User . Password ( )
}
switch info . Type {
case PostgreSQL , MySQL :
info . Database = strings . TrimPrefix ( u . Path , "/" )
2024-11-25 22:07:13 +01:00
case SQLite :
info . Database = u . String ( )
2024-09-20 16:01:09 +02:00
default :
return nil , fmt . Errorf ( "unknown database: %v" , err )
2023-09-17 00:59:47 +02:00
}
2024-09-20 16:01:09 +02:00
return info , nil
}
2025-01-20 01:57:39 +01:00
func Setup ( dbUri string ) error {
2024-09-20 16:01:09 +02:00
dbInfo , err := parseDBURI ( dbUri )
if err != nil {
2023-03-14 16:22:52 +01:00
return err
}
2024-09-20 16:01:09 +02:00
log . Info ( ) . Msgf ( "Setting up a %s database connection" , dbInfo . Type )
2025-01-20 01:57:39 +01:00
var setupFunc func ( databaseInfo ) error
2024-09-20 16:01:09 +02:00
switch dbInfo . Type {
case SQLite :
setupFunc = setupSQLite
case PostgreSQL :
setupFunc = setupPostgres
case MySQL :
setupFunc = setupMySQL
default :
return fmt . Errorf ( "unknown database type: %v" , dbInfo . Type )
}
maxAttempts := 60
retryInterval := 1 * time . Second
for attempt := 1 ; attempt <= maxAttempts ; attempt ++ {
2025-01-20 01:57:39 +01:00
err = setupFunc ( * dbInfo )
2024-09-20 16:01:09 +02:00
if err == nil {
log . Info ( ) . Msg ( "Database connection established" )
break
}
if attempt < maxAttempts {
log . Warn ( ) . Err ( err ) . Msgf ( "Failed to connect to database (attempt %d), retrying in %v..." , attempt , retryInterval )
time . Sleep ( retryInterval )
} else {
return err
}
}
DatabaseInfo = dbInfo
2023-06-21 18:19:17 +02:00
if err = db . SetupJoinTable ( & Gist { } , "Likes" , & Like { } ) ; err != nil {
return err
}
if err = db . SetupJoinTable ( & User { } , "Liked" , & Like { } ) ; err != nil {
return err
}
2024-10-24 23:23:00 +02:00
if err = db . AutoMigrate ( & User { } , & Gist { } , & SSHKey { } , & AdminSetting { } , & Invitation { } , & WebAuthnCredential { } , & TOTP { } ) ; err != nil {
2023-03-14 16:22:52 +01:00
return err
}
2025-01-20 01:57:39 +01:00
if err = applyMigrations ( dbInfo ) ; err != nil {
2023-09-17 02:55:17 +02:00
return err
}
2023-04-18 02:33:19 +02:00
2023-04-17 00:17:06 +02:00
// Default admin setting values
return initAdminSettings ( map [ string ] string {
2024-05-12 14:40:11 -07:00
SettingDisableSignup : "0" ,
SettingRequireLogin : "0" ,
SettingAllowGistsWithoutLogin : "0" ,
SettingDisableLoginForm : "0" ,
SettingDisableGravatar : "0" ,
2023-04-17 00:17:06 +02:00
} )
2023-03-14 16:22:52 +01:00
}
2023-09-17 00:59:47 +02:00
// Close closes the database/sql connection pool behind the package-level GORM
// handle. It returns the error from obtaining the pool or from closing it.
func Close() error {
	pool, err := db.DB()
	if err == nil {
		err = pool.Close()
	}
	return err
}
2023-03-14 16:22:52 +01:00
// CountAll returns the number of rows in the table GORM associates with the
// given model value, together with any error GORM reports.
func CountAll(table any) (count int64, err error) {
	err = db.Model(table).Count(&count).Error
	return count, err
}
2023-04-17 14:25:39 +02:00
func IsUniqueConstraintViolation ( err error ) bool {
2024-09-20 16:01:09 +02:00
return errors . Is ( err , gorm . ErrDuplicatedKey )
2023-04-17 14:25:39 +02:00
}
2023-12-16 01:27:00 +01:00
// Ping checks that the database behind the package-level handle is reachable.
// It returns the error from obtaining the connection pool or from pinging it.
func Ping() error {
	conn, err := db.DB()
	if err == nil {
		err = conn.Ping()
	}
	return err
}
2024-09-20 16:01:09 +02:00
2025-01-20 01:57:39 +01:00
func setupSQLite ( dbInfo databaseInfo ) error {
2024-09-20 16:01:09 +02:00
var err error
2025-01-20 01:57:39 +01:00
var dsn string
2024-09-20 16:01:09 +02:00
journalMode := strings . ToUpper ( config . C . SqliteJournalMode )
if ! slices . Contains ( [ ] string { "DELETE" , "TRUNCATE" , "PERSIST" , "MEMORY" , "WAL" , "OFF" } , journalMode ) {
log . Warn ( ) . Msg ( "Invalid SQLite journal mode: " + journalMode )
}
2025-01-20 01:57:39 +01:00
if dbInfo . Database == ":memory:" {
dsn = ":memory:?_fk=true&cache=shared"
} else {
u , err := url . Parse ( dbInfo . Database )
if err != nil {
return err
}
2024-09-20 16:01:09 +02:00
2025-01-20 01:57:39 +01:00
u . Scheme = "file"
q := u . Query ( )
q . Set ( "_fk" , "true" )
q . Set ( "_journal_mode" , journalMode )
u . RawQuery = q . Encode ( )
dsn = u . String ( )
2024-11-25 22:07:13 +01:00
}
db , err = gorm . Open ( sqlite . Open ( dsn ) , & gorm . Config {
2024-09-20 16:01:09 +02:00
Logger : logger . Default . LogMode ( logger . Silent ) ,
TranslateError : true ,
} )
return err
}
2025-01-20 01:57:39 +01:00
func setupPostgres ( dbInfo databaseInfo ) error {
2024-09-20 16:01:09 +02:00
var err error
dsn := fmt . Sprintf ( "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable" , dbInfo . Host , dbInfo . Port , dbInfo . User , dbInfo . Password , dbInfo . Database )
db , err = gorm . Open ( postgres . Open ( dsn ) , & gorm . Config {
Logger : logger . Default . LogMode ( logger . Silent ) ,
TranslateError : true ,
} )
return err
}
2025-01-20 01:57:39 +01:00
func setupMySQL ( dbInfo databaseInfo ) error {
2024-09-20 16:01:09 +02:00
var err error
dsn := fmt . Sprintf ( "%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local" , dbInfo . User , dbInfo . Password , dbInfo . Host , dbInfo . Port , dbInfo . Database )
db , err = gorm . Open ( mysql . New ( mysql . Config {
DSN : dsn ,
DontSupportRenameIndex : true ,
} ) , & gorm . Config {
Logger : logger . Default . LogMode ( logger . Silent ) ,
TranslateError : true ,
} )
return err
}
// DeprecationDBFilename handles the deprecated db-filename option. It logs a
// deprecation warning when config.C.DBFilename is set. When no DBUri is
// configured, it copies DBFilename into DBUri so the old option keeps working.
func DeprecationDBFilename() {
	usesDeprecatedOption := config.C.DBFilename != ""
	if usesDeprecatedOption {
		log.Warn().Msg("The 'db-filename'/'OG_DB_FILENAME' configuration option is deprecated and will be removed in a future version. Please use 'db-uri'/'OG_DB_URI' instead.")
	}

	if config.C.DBUri != "" {
		return
	}
	config.C.DBUri = config.C.DBFilename
}
func TruncateDatabase ( ) error {
2024-10-24 23:23:00 +02:00
return db . Migrator ( ) . DropTable ( "likes" , & User { } , "gists" , & SSHKey { } , & AdminSetting { } , & Invitation { } , & WebAuthnCredential { } , & TOTP { } )
2024-09-20 16:01:09 +02:00
}