diff --git a/internal/auth/oauth/gitea.go b/internal/auth/oauth/gitea.go new file mode 100644 index 0000000..48af2ca --- /dev/null +++ b/internal/auth/oauth/gitea.go @@ -0,0 +1,33 @@ +package oauth + +import ( + "github.com/markbates/goth" + "github.com/markbates/goth/providers/gitea" + "github.com/thomiceli/opengist/internal/config" +) + +type GiteaProvider struct { + Provider + URL string +} + +func (p *GiteaProvider) RegisterProvider() error { + goth.UseProviders( + gitea.NewCustomisedURL( + config.C.GiteaClientKey, + config.C.GiteaSecret, + urlJoin(p.URL, "/oauth/gitea/callback"), + urlJoin(config.C.GiteaUrl, "/login/oauth/authorize"), + urlJoin(config.C.GiteaUrl, "/login/oauth/access_token"), + urlJoin(config.C.GiteaUrl, "/api/v1/user"), + ), + ) + + return nil +} + +func NewGiteaProvider(url string) *GiteaProvider { + return &GiteaProvider{ + URL: url, + } +} diff --git a/internal/auth/oauth/github.go b/internal/auth/oauth/github.go new file mode 100644 index 0000000..8b68046 --- /dev/null +++ b/internal/auth/oauth/github.go @@ -0,0 +1,39 @@ +package oauth + +import ( + "github.com/markbates/goth" + "github.com/markbates/goth/providers/github" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" +) + +type GitHubProvider struct { + Provider + URL string +} + +func (p *GitHubProvider) RegisterProvider() error { + goth.UseProviders( + github.New( + config.C.GithubClientKey, + config.C.GithubSecret, + urlJoin(p.URL, "/oauth/github/callback"), + ), + ) + + return nil +} + +func (p *GitHubProvider) SSHKeysURL(user *db.User) string { + +} + +func (p *GitHubProvider) GetProviderUserID(user *db.User) bool { + return user.GithubID != "" +} + +func NewGitHubProvider(url string) *GitHubProvider { + return &GitHubProvider{ + URL: url, + } +} diff --git a/internal/auth/oauth/gitlab.go b/internal/auth/oauth/gitlab.go new file mode 100644 index 0000000..12e8c43 --- /dev/null +++ b/internal/auth/oauth/gitlab.go @@ -0,0 +1,42 @@ +package oauth + +import ( + "github.com/markbates/goth" + "github.com/markbates/goth/providers/gitlab" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" +) + +type GitLabProvider struct { + Provider + URL string +} + +func (p *GitLabProvider) RegisterProvider() error { + goth.UseProviders( + gitlab.NewCustomisedURL( + config.C.GitlabClientKey, + config.C.GitlabSecret, + urlJoin(p.URL, "/oauth/gitlab/callback"), + urlJoin(config.C.GitlabUrl, "/oauth/authorize"), + urlJoin(config.C.GitlabUrl, "/oauth/token"), + urlJoin(config.C.GitlabUrl, "/api/v4/user"), + ), + ) + + return nil +} + +func (p *GitLabProvider) SSHKeysURL(user *db.User) string { + +} + +func (p *GitLabProvider) GetProviderUserID(user *db.User) bool { + return user.GitlabID != "" +} + +func NewGitLabProvider(url string) *GitLabProvider { + return &GitLabProvider{ + URL: url, + } +} diff --git a/internal/auth/oauth/openid.go b/internal/auth/oauth/openid.go new file mode 100644 index 0000000..0e0949f --- /dev/null +++ b/internal/auth/oauth/openid.go @@ -0,0 +1,38 @@ +package oauth + +import ( + "errors" + "github.com/markbates/goth" + "github.com/markbates/goth/providers/openidConnect" + "github.com/thomiceli/opengist/internal/config" +) + +type OIDCProvider struct { + Provider + URL string +} + +func (p *OIDCProvider) RegisterProvider() error { + oidcProvider, err := openidConnect.New( + config.C.OIDCClientKey, + config.C.OIDCSecret, + urlJoin(p.URL, "/oauth/openid-connect/callback"), + config.C.OIDCDiscoveryUrl, + "openid", + "email", + "profile", + ) + + if err != nil { + return errors.New("Cannot create OIDC provider: " + err.Error()) + } + + goth.UseProviders(oidcProvider) + return nil +} + +func NewOIDCProvider(url string) *OIDCProvider { + return &OIDCProvider{ + URL: url, + } +} diff --git a/internal/auth/oauth/provider.go b/internal/auth/oauth/provider.go new file mode 100644 index 0000000..bd19bf3 --- /dev/null +++ b/internal/auth/oauth/provider.go @@ -0,0 +1,46 @@ +package oauth + +import ( + "github.com/rs/zerolog/log" + "github.com/thomiceli/opengist/internal/db" + "net/url" +) + +const ( + GitHubProviderString = "github" + GitLabProviderString = "gitlab" + GiteaProviderString = "gitea" + OpenIDConnectString = "openid-connect" +) + +type Provider interface { + RegisterProvider() error + SSHKeysURL(user *db.User) string + GetProviderUserID(user *db.User) bool + SetProviderUserID() error + GetAvatarURL(user *db.User) string +} + +func DefineProvider(provider string, url string) Provider { + switch provider { + case GitHubProviderString: + return NewGitHubProvider(url) + case GitLabProviderString: + return NewGitLabProvider(url) + case GiteaProviderString: + return NewGiteaProvider(url) + case OpenIDConnectString: + return NewOIDCProvider(url) + } + + return nil +} + +func urlJoin(base string, elem ...string) string { + joined, err := url.JoinPath(base, elem...) + if err != nil { + log.Error().Err(err).Msg("Cannot join url") + } + + return joined +} diff --git a/internal/web/handlers/auth.go b/internal/web/handlers/auth.go index 84a1259..24314fd 100644 --- a/internal/web/handlers/auth.go +++ b/internal/web/handlers/auth.go @@ -1,807 +1,9 @@ package handlers import ( - "bytes" - gocontext "context" - "crypto/md5" - gojson "encoding/json" - "errors" - "fmt" - "github.com/markbates/goth" - "github.com/markbates/goth/gothic" - "github.com/markbates/goth/providers/gitea" - "github.com/markbates/goth/providers/github" - "github.com/markbates/goth/providers/gitlab" - "github.com/markbates/goth/providers/openidConnect" - "github.com/rs/zerolog/log" - "github.com/thomiceli/opengist/internal/auth/totp" - "github.com/thomiceli/opengist/internal/auth/webauthn" - "github.com/thomiceli/opengist/internal/config" - "github.com/thomiceli/opengist/internal/db" - "github.com/thomiceli/opengist/internal/i18n" - "github.com/thomiceli/opengist/internal/utils" "github.com/thomiceli/opengist/internal/web/context" - "golang.org/x/text/cases" - "golang.org/x/text/language" - "gorm.io/gorm" - "io" - "net/http" - "net/url" - "strings" ) -const ( - GitHubProvider = "github" - GitLabProvider = "gitlab" - GiteaProvider = "gitea" - OpenIDConnect = "openid-connect" -) - -func Register(ctx *context.Context) error { - disableSignup := ctx.GetData("DisableSignup") - disableForm := ctx.GetData("DisableLoginForm") - - code := ctx.QueryParam("code") - if code != "" { - if invitation, err := db.GetInvitationByCode(code); err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return ctx.ErrorRes(500, "Cannot check for invitation code", err) - } else if invitation != nil && invitation.IsUsable() { - disableSignup = false - } - } - - ctx.SetData("title", ctx.TrH("auth.new-account")) - ctx.SetData("htmlTitle", ctx.TrH("auth.new-account")) - ctx.SetData("disableForm", disableForm) - ctx.SetData("disableSignup", disableSignup) - ctx.SetData("isLoginPage", false) - return ctx.HTML_("auth_form.html") -} - -func ProcessRegister(ctx *context.Context) error { - disableSignup := ctx.GetData("DisableSignup") - - code := ctx.QueryParam("code") - invitation, err := db.GetInvitationByCode(code) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return ctx.ErrorRes(500, "Cannot check for invitation code", err) - } else if invitation.ID != 0 && invitation.IsUsable() { - disableSignup = false - } - - if disableSignup == true { - return ctx.ErrorRes(403, ctx.Tr("error.signup-disabled"), nil) - } - - if ctx.GetData("DisableLoginForm") == true { - return ctx.ErrorRes(403, ctx.Tr("error.signup-disabled-form"), nil) - } - - ctx.SetData("title", ctx.TrH("auth.new-account")) - ctx.SetData("htmlTitle", ctx.TrH("auth.new-account")) - - sess := ctx.GetSession() - - dto := new(db.UserDTO) - if err := ctx.Bind(dto); err != nil { - return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) - } - - if err := ctx.Validate(dto); err != nil { - ctx.AddFlash(utils.ValidationMessages(&err, ctx.GetData("locale").(*i18n.Locale)), "error") - return ctx.HTML_("auth_form.html") - } - - if exists, err := db.UserExists(dto.Username); err != nil || exists { - ctx.AddFlash(ctx.Tr("flash.auth.username-exists"), "error") - return ctx.HTML_("auth_form.html") - } - - user := dto.ToUser() - - password, err := utils.Argon2id.Hash(user.Password) - if err != nil { - return ctx.ErrorRes(500, "Cannot hash password", err) - } - user.Password = password - - if err = user.Create(); err != nil { - return ctx.ErrorRes(500, "Cannot create user", err) - } - - if user.ID == 1 { - if err = user.SetAdmin(); err != nil { - return ctx.ErrorRes(500, "Cannot set user admin", err) - } - } - - if invitation.ID != 0 { - if err := invitation.Use(); err != nil { - return ctx.ErrorRes(500, "Cannot use invitation", err) - } - } - - sess.Values["user"] = user.ID - ctx.SaveSession(sess) - - return ctx.RedirectTo("/") -} - -func Login(ctx *context.Context) error { - ctx.SetData("title", ctx.TrH("auth.login")) - ctx.SetData("htmlTitle", ctx.TrH("auth.login")) - ctx.SetData("disableForm", ctx.GetData("DisableLoginForm")) - ctx.SetData("isLoginPage", true) - return ctx.HTML_("auth_form.html") -} - -func ProcessLogin(ctx *context.Context) error { - if ctx.GetData("DisableLoginForm") == true { - return ctx.ErrorRes(403, ctx.Tr("error.login-disabled-form"), nil) - } - - var err error - sess := ctx.GetSession() - - dto := &db.UserDTO{} - if err = ctx.Bind(dto); err != nil { - return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) - } - password := dto.Password - - var user *db.User - - if user, err = db.GetUserByUsername(dto.Username); err != nil { - if !errors.Is(err, gorm.ErrRecordNotFound) { - return ctx.ErrorRes(500, "Cannot get user", err) - } - log.Warn().Msg("Invalid HTTP authentication attempt from " + ctx.RealIP()) - ctx.AddFlash(ctx.Tr("flash.auth.invalid-credentials"), "error") - return ctx.RedirectTo("/login") - } - - if ok, err := utils.Argon2id.Verify(password, user.Password); !ok { - if err != nil { - return ctx.ErrorRes(500, "Cannot check for password", err) - } - log.Warn().Msg("Invalid HTTP authentication attempt from " + ctx.RealIP()) - ctx.AddFlash(ctx.Tr("flash.auth.invalid-credentials"), "error") - return ctx.RedirectTo("/login") - } - - // handle MFA - var hasWebauthn, hasTotp bool - if hasWebauthn, hasTotp, err = user.HasMFA(); err != nil { - return ctx.ErrorRes(500, "Cannot check for user MFA", err) - } - if hasWebauthn || hasTotp { - sess.Values["mfaID"] = user.ID - sess.Options.MaxAge = 5 * 60 // 5 minutes - ctx.SaveSession(sess) - return ctx.RedirectTo("/mfa") - } - - sess.Values["user"] = user.ID - sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year - ctx.SaveSession(sess) - ctx.DeleteCsrfCookie() - - return ctx.RedirectTo("/") -} - -func Mfa(ctx *context.Context) error { - var err error - - user := db.User{ID: ctx.GetSession().Values["mfaID"].(uint)} - - var hasWebauthn, hasTotp bool - if hasWebauthn, hasTotp, err = user.HasMFA(); err != nil { - return ctx.ErrorRes(500, "Cannot check for user MFA", err) - } - - ctx.SetData("hasWebauthn", hasWebauthn) - ctx.SetData("hasTotp", hasTotp) - - return ctx.HTML_("mfa.html") -} - -func OauthCallback(ctx *context.Context) error { - user, err := gothic.CompleteUserAuth(ctx.Response(), ctx.Request()) - if err != nil { - return ctx.ErrorRes(400, ctx.Tr("error.complete-oauth-login", err.Error()), err) - } - - currUser := ctx.User - if currUser != nil { - // if user is logged in, link account to user and update its avatar URL - updateUserProviderInfo(currUser, user.Provider, user) - - if err = currUser.Update(); err != nil { - return ctx.ErrorRes(500, "Cannot update user "+cases.Title(language.English).String(user.Provider)+" id", err) - } - - ctx.AddFlash(ctx.Tr("flash.auth.account-linked-oauth", cases.Title(language.English).String(user.Provider)), "success") - return ctx.RedirectTo("/settings") - } - - // if user is not in database, create it - userDB, err := db.GetUserByProvider(user.UserID, user.Provider) - if err != nil { - if ctx.GetData("DisableSignup") == true { - return ctx.ErrorRes(403, ctx.Tr("error.signup-disabled"), nil) - } - - if !errors.Is(err, gorm.ErrRecordNotFound) { - return ctx.ErrorRes(500, "Cannot get user", err) - } - - if user.NickName == "" { - user.NickName = strings.Split(user.Email, "@")[0] - } - - userDB = &db.User{ - Username: user.NickName, - Email: user.Email, - MD5Hash: fmt.Sprintf("%x", md5.Sum([]byte(strings.ToLower(strings.TrimSpace(user.Email))))), - } - - // set provider id and avatar URL - updateUserProviderInfo(userDB, user.Provider, user) - - if err = userDB.Create(); err != nil { - if db.IsUniqueConstraintViolation(err) { - ctx.AddFlash(ctx.Tr("flash.auth.username-exists"), "error") - return ctx.RedirectTo("/login") - } - - return ctx.ErrorRes(500, "Cannot create user", err) - } - - if userDB.ID == 1 { - if err = userDB.SetAdmin(); err != nil { - return ctx.ErrorRes(500, "Cannot set user admin", err) - } - } - - var resp *http.Response - switch user.Provider { - case GitHubProvider: - resp, err = http.Get("https://github.com/" + user.NickName + ".keys") - case GitLabProvider: - resp, err = http.Get(urlJoin(config.C.GitlabUrl, user.NickName+".keys")) - case GiteaProvider: - resp, err = http.Get(urlJoin(config.C.GiteaUrl, user.NickName+".keys")) - case OpenIDConnect: - err = errors.New("cannot get keys from OIDC provider") - } - - if err == nil { - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-retrievable"), "error") - log.Error().Err(err).Msg("Could not get user keys") - } - - keys := strings.Split(string(body), "\n") - if len(keys[len(keys)-1]) == 0 { - keys = keys[:len(keys)-1] - } - for _, key := range keys { - sshKey := db.SSHKey{ - Title: "Added from " + user.Provider, - Content: key, - User: *userDB, - } - - if err = sshKey.Create(); err != nil { - ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-created"), "error") - log.Error().Err(err).Msg("Could not create ssh key") - } - } - } - } - - sess := ctx.GetSession() - sess.Values["user"] = userDB.ID - ctx.SaveSession(sess) - ctx.DeleteCsrfCookie() - - return ctx.RedirectTo("/") -} - -func Oauth(ctx *context.Context) error { - provider := ctx.Param("provider") - - httpProtocol := "http" - if ctx.Request().TLS != nil || ctx.Request().Header.Get("X-Forwarded-Proto") == "https" { - httpProtocol = "https" - } - - forwarded_hdr := ctx.Request().Header.Get("Forwarded") - if forwarded_hdr != "" { - fields := strings.Split(forwarded_hdr, ";") - fwd := make(map[string]string) - for _, v := range fields { - p := strings.Split(v, "=") - fwd[p[0]] = p[1] - } - val, ok := fwd["proto"] - if ok && val == "https" { - httpProtocol = "https" - } - } - - var opengistUrl string - if config.C.ExternalUrl != "" { - opengistUrl = config.C.ExternalUrl - } else { - opengistUrl = httpProtocol + "://" + ctx.Request().Host - } - - switch provider { - case GitHubProvider: - goth.UseProviders( - github.New( - config.C.GithubClientKey, - config.C.GithubSecret, - urlJoin(opengistUrl, "/oauth/github/callback"), - ), - ) - - case GitLabProvider: - goth.UseProviders( - gitlab.NewCustomisedURL( - config.C.GitlabClientKey, - config.C.GitlabSecret, - urlJoin(opengistUrl, "/oauth/gitlab/callback"), - urlJoin(config.C.GitlabUrl, "/oauth/authorize"), - urlJoin(config.C.GitlabUrl, "/oauth/token"), - urlJoin(config.C.GitlabUrl, "/api/v4/user"), - ), - ) - - case GiteaProvider: - goth.UseProviders( - gitea.NewCustomisedURL( - config.C.GiteaClientKey, - config.C.GiteaSecret, - urlJoin(opengistUrl, "/oauth/gitea/callback"), - urlJoin(config.C.GiteaUrl, "/login/oauth/authorize"), - urlJoin(config.C.GiteaUrl, "/login/oauth/access_token"), - urlJoin(config.C.GiteaUrl, "/api/v1/user"), - ), - ) - case OpenIDConnect: - oidcProvider, err := openidConnect.New( - config.C.OIDCClientKey, - config.C.OIDCSecret, - urlJoin(opengistUrl, "/oauth/openid-connect/callback"), - config.C.OIDCDiscoveryUrl, - "openid", - "email", - "profile", - ) - - if err != nil { - return ctx.ErrorRes(500, "Cannot create OIDC provider", err) - } - - goth.UseProviders(oidcProvider) - } - - ctxValue := gocontext.WithValue(ctx.Request().Context(), gothic.ProviderParamKey, provider) - ctx.SetRequest(ctx.Request().WithContext(ctxValue)) - if provider != GitHubProvider && provider != GitLabProvider && provider != GiteaProvider && provider != OpenIDConnect { - return ctx.ErrorRes(400, ctx.Tr("error.oauth-unsupported"), nil) - } - - gothic.BeginAuthHandler(ctx.Response(), ctx.Request()) - return nil -} - -func OauthUnlink(ctx *context.Context) error { - provider := ctx.Param("provider") - - currUser := ctx.User - // Map each provider to a function that checks the relevant ID in currUser - providerIDCheckMap := map[string]func() bool{ - GitHubProvider: func() bool { return currUser.GithubID != "" }, - GitLabProvider: func() bool { return currUser.GitlabID != "" }, - GiteaProvider: func() bool { return currUser.GiteaID != "" }, - OpenIDConnect: func() bool { return currUser.OIDCID != "" }, - } - - if checkFunc, exists := providerIDCheckMap[provider]; exists && checkFunc() { - if err := currUser.DeleteProviderID(provider); err != nil { - return ctx.ErrorRes(500, "Cannot unlink account from "+cases.Title(language.English).String(provider), err) - } - - ctx.AddFlash(ctx.Tr("flash.auth.account-unlinked-oauth", cases.Title(language.English).String(provider)), "success") - return ctx.RedirectTo("/settings") - } - - return ctx.RedirectTo("/settings") -} - -func BeginWebAuthnBinding(ctx *context.Context) error { - credsCreation, jsonWaSession, err := webauthn.BeginBinding(ctx.User) - if err != nil { - return ctx.ErrorRes(500, "Cannot begin WebAuthn registration", err) - } - - sess := ctx.GetSession() - sess.Values["webauthn_registration_session"] = jsonWaSession - sess.Options.MaxAge = 5 * 60 // 5 minutes - ctx.SaveSession(sess) - - return ctx.JSON(200, credsCreation) -} - -func FinishWebAuthnBinding(ctx *context.Context) error { - sess := ctx.GetSession() - jsonWaSession, ok := sess.Values["webauthn_registration_session"].([]byte) - if !ok { - return ctx.ErrorRes(401, "Cannot get WebAuthn registration session", nil) - } - - user := ctx.User - - // extract passkey name from request - body, err := io.ReadAll(ctx.Request().Body) - if err != nil { - return ctx.ErrorRes(400, "Failed to read request body", err) - } - ctx.Request().Body.Close() - ctx.Request().Body = io.NopCloser(bytes.NewBuffer(body)) - - dto := new(db.CrendentialDTO) - _ = gojson.Unmarshal(body, &dto) - - if err = ctx.Validate(dto); err != nil { - return ctx.ErrorRes(400, "Invalid request", err) - } - passkeyName := dto.PasskeyName - if passkeyName == "" { - passkeyName = "WebAuthn" - } - - waCredential, err := webauthn.FinishBinding(user, jsonWaSession, ctx.Request()) - if err != nil { - return ctx.ErrorRes(403, "Failed binding attempt for passkey", err) - } - - if _, err = db.CreateFromCrendential(user.ID, passkeyName, waCredential); err != nil { - return ctx.ErrorRes(500, "Cannot create WebAuthn credential on database", err) - } - - delete(sess.Values, "webauthn_registration_session") - ctx.SaveSession(sess) - - ctx.AddFlash(ctx.Tr("flash.auth.passkey-registred", passkeyName), "success") - return ctx.JSON_([]string{"OK"}) -} - -func BeginWebAuthnLogin(ctx *context.Context) error { - credsCreation, jsonWaSession, err := webauthn.BeginDiscoverableLogin() - if err != nil { - return ctx.ErrorRes(401, "Cannot begin WebAuthn login", err) - } - - sess := ctx.GetSession() - sess.Values["webauthn_login_session"] = jsonWaSession - sess.Options.MaxAge = 5 * 60 // 5 minutes - ctx.SaveSession(sess) - - return ctx.JSON_(credsCreation) -} - -func FinishWebAuthnLogin(ctx *context.Context) error { - sess := ctx.GetSession() - sessionData, ok := sess.Values["webauthn_login_session"].([]byte) - if !ok { - return ctx.ErrorRes(401, "Cannot get WebAuthn login session", nil) - } - - userID, err := webauthn.FinishDiscoverableLogin(sessionData, ctx.Request()) - if err != nil { - return ctx.ErrorRes(403, "Failed authentication attempt for passkey", err) - } - - sess.Values["user"] = userID - sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year - - delete(sess.Values, "webauthn_login_session") - ctx.SaveSession(sess) - - return ctx.JSON_([]string{"OK"}) -} - -func BeginWebAuthnAssertion(ctx *context.Context) error { - sess := ctx.GetSession() - - ogUser, err := db.GetUserById(sess.Values["mfaID"].(uint)) - if err != nil { - return ctx.ErrorRes(500, "Cannot get user", err) - } - - credsCreation, jsonWaSession, err := webauthn.BeginLogin(ogUser) - if err != nil { - return ctx.ErrorRes(401, "Cannot begin WebAuthn login", err) - } - - sess.Values["webauthn_assertion_session"] = jsonWaSession - sess.Options.MaxAge = 5 * 60 // 5 minutes - ctx.SaveSession(sess) - - return ctx.JSON_(credsCreation) -} - -func FinishWebAuthnAssertion(ctx *context.Context) error { - sess := ctx.GetSession() - sessionData, ok := sess.Values["webauthn_assertion_session"].([]byte) - if !ok { - return ctx.ErrorRes(401, "Cannot get WebAuthn assertion session", nil) - } - - userId := sess.Values["mfaID"].(uint) - - ogUser, err := db.GetUserById(userId) - if err != nil { - return ctx.ErrorRes(500, "Cannot get user", err) - } - - if err = webauthn.FinishLogin(ogUser, sessionData, ctx.Request()); err != nil { - return ctx.ErrorRes(403, "Failed authentication attempt for passkey", err) - } - - sess.Values["user"] = userId - sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year - - delete(sess.Values, "webauthn_assertion_session") - delete(sess.Values, "mfaID") - ctx.SaveSession(sess) - - return ctx.JSON_([]string{"OK"}) -} - -func BeginTotp(ctx *context.Context) error { - user := ctx.User - - if _, hasTotp, err := user.HasMFA(); err != nil { - return ctx.ErrorRes(500, "Cannot check for user MFA", err) - } else if hasTotp { - ctx.AddFlash(ctx.Tr("auth.totp.already-enabled"), "error") - return ctx.RedirectTo("/settings") - } - - ogUrl, err := url.Parse(ctx.GetData("baseHttpUrl").(string)) - if err != nil { - return ctx.ErrorRes(500, "Cannot parse base URL", err) - } - - sess := ctx.GetSession() - generatedSecret, _ := sess.Values["generatedSecret"].([]byte) - - totpSecret, qrcode, err, generatedSecret := totp.GenerateQRCode(ctx.User.Username, ogUrl.Hostname(), generatedSecret) - if err != nil { - return ctx.ErrorRes(500, "Cannot generate TOTP QR code", err) - } - sess.Values["totpSecret"] = totpSecret - sess.Values["generatedSecret"] = generatedSecret - ctx.SaveSession(sess) - - ctx.SetData("totpSecret", totpSecret) - ctx.SetData("totpQrcode", qrcode) - - return ctx.HTML_("totp.html") - -} - -func FinishTotp(ctx *context.Context) error { - user := ctx.User - - if _, hasTotp, err := user.HasMFA(); err != nil { - return ctx.ErrorRes(500, "Cannot check for user MFA", err) - } else if hasTotp { - ctx.AddFlash(ctx.Tr("auth.totp.already-enabled"), "error") - return ctx.RedirectTo("/settings") - } - - dto := &db.TOTPDTO{} - if err := ctx.Bind(dto); err != nil { - return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) - } - - if err := ctx.Validate(dto); err != nil { - ctx.AddFlash("Invalid secret", "error") - return ctx.RedirectTo("/settings/totp/generate") - } - - sess := ctx.GetSession() - secret, ok := sess.Values["totpSecret"].(string) - if !ok { - return ctx.ErrorRes(500, "Cannot get TOTP secret from session", nil) - } - - if !totp.Validate(dto.Code, secret) { - ctx.AddFlash(ctx.Tr("auth.totp.invalid-code"), "error") - - return ctx.RedirectTo("/settings/totp/generate") - } - - userTotp := &db.TOTP{ - UserID: ctx.User.ID, - } - if err := userTotp.StoreSecret(secret); err != nil { - return ctx.ErrorRes(500, "Cannot store TOTP secret", err) - } - - if err := userTotp.Create(); err != nil { - return ctx.ErrorRes(500, "Cannot create TOTP", err) - } - - ctx.AddFlash("TOTP successfully enabled", "success") - codes, err := userTotp.GenerateRecoveryCodes() - if err != nil { - return ctx.ErrorRes(500, "Cannot generate recovery codes", err) - } - - delete(sess.Values, "totpSecret") - delete(sess.Values, "generatedSecret") - ctx.SaveSession(sess) - - ctx.SetData("recoveryCodes", codes) - return ctx.HTML_("totp.html") -} - -func AssertTotp(ctx *context.Context) error { - var err error - dto := &db.TOTPDTO{} - if err := ctx.Bind(dto); err != nil { - return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) - } - - if err := ctx.Validate(dto); err != nil { - ctx.AddFlash(ctx.Tr("auth.totp.invalid-code"), "error") - return ctx.RedirectTo("/mfa") - } - - sess := ctx.GetSession() - userId := sess.Values["mfaID"].(uint) - var userTotp *db.TOTP - if userTotp, err = db.GetTOTPByUserID(userId); err != nil { - return ctx.ErrorRes(500, "Cannot get TOTP by UID", err) - } - - redirectUrl := "/" - - var validCode, validRecoveryCode bool - if validCode, err = userTotp.ValidateCode(dto.Code); err != nil { - return ctx.ErrorRes(500, "Cannot validate TOTP code", err) - } - if !validCode { - validRecoveryCode, err = userTotp.ValidateRecoveryCode(dto.Code) - if err != nil { - return ctx.ErrorRes(500, "Cannot validate TOTP code", err) - } - - if !validRecoveryCode { - ctx.AddFlash(ctx.Tr("auth.totp.invalid-code"), "error") - return ctx.RedirectTo("/mfa") - } - - ctx.AddFlash(ctx.Tr("auth.totp.code-used", dto.Code), "warning") - redirectUrl = "/settings" - } - - sess.Values["user"] = userId - sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year - delete(sess.Values, "mfaID") - ctx.SaveSession(sess) - - return ctx.RedirectTo(redirectUrl) -} - -func DisableTotp(ctx *context.Context) error { - user := ctx.User - userTotp, err := db.GetTOTPByUserID(user.ID) - if err != nil { - return ctx.ErrorRes(500, "Cannot get TOTP by UID", err) - } - - if err = userTotp.Delete(); err != nil { - return ctx.ErrorRes(500, "Cannot delete TOTP", err) - } - - ctx.AddFlash(ctx.Tr("auth.totp.disabled"), "success") - return ctx.RedirectTo("/settings") -} - -func RegenerateTotpRecoveryCodes(ctx *context.Context) error { - user := ctx.User - userTotp, err := db.GetTOTPByUserID(user.ID) - if err != nil { - return ctx.ErrorRes(500, "Cannot get TOTP by UID", err) - } - - codes, err := userTotp.GenerateRecoveryCodes() - if err != nil { - return ctx.ErrorRes(500, "Cannot generate recovery codes", err) - } - - ctx.SetData("recoveryCodes", codes) - return ctx.HTML_("totp.html") -} - -func Logout(ctx *context.Context) error { - ctx.DeleteSession() - ctx.DeleteCsrfCookie() - return ctx.RedirectTo("/all") -} - -func urlJoin(base string, elem ...string) string { - joined, err := url.JoinPath(base, elem...) - if err != nil { - log.Error().Err(err).Msg("Cannot join url") - } - - return joined -} - -func updateUserProviderInfo(userDB *db.User, provider string, user goth.User) { - userDB.AvatarURL = getAvatarUrlFromProvider(provider, user.UserID) - switch provider { - case GitHubProvider: - userDB.GithubID = user.UserID - case GitLabProvider: - userDB.GitlabID = user.UserID - case GiteaProvider: - userDB.GiteaID = user.UserID - case OpenIDConnect: - userDB.OIDCID = user.UserID - userDB.AvatarURL = user.AvatarURL - } -} - -func getAvatarUrlFromProvider(provider string, identifier string) string { - switch provider { - case GitHubProvider: - return "https://avatars.githubusercontent.com/u/" + identifier + "?v=4" - case GitLabProvider: - return urlJoin(config.C.GitlabUrl, "/uploads/-/system/user/avatar/", identifier, "/avatar.png") + "?width=400" - case GiteaProvider: - resp, err := http.Get(urlJoin(config.C.GiteaUrl, "/api/v1/users/", identifier)) - if err != nil { - log.Error().Err(err).Msg("Cannot get user from Gitea") - return "" - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Error().Err(err).Msg("Cannot read Gitea response body") - return "" - } - - var result map[string]interface{} - err = gojson.Unmarshal(body, &result) - if err != nil { - log.Error().Err(err).Msg("Cannot unmarshal Gitea response body") - return "" - } - - field, ok := result["avatar_url"] - if !ok { - log.Error().Msg("Field 'avatar_url' not found in Gitea JSON response") - return "" - } - return field.(string) - } - return "" -} - type ContextAuthInfo struct { Context *context.Context } diff --git a/internal/web/handlers/auth/mfa.go b/internal/web/handlers/auth/mfa.go new file mode 100644 index 0000000..e002c01 --- /dev/null +++ b/internal/web/handlers/auth/mfa.go @@ -0,0 +1,22 @@ +package auth + +import ( + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" +) + +func Mfa(ctx *context.Context) error { + var err error + + user := db.User{ID: ctx.GetSession().Values["mfaID"].(uint)} + + var hasWebauthn, hasTotp bool + if hasWebauthn, hasTotp, err = user.HasMFA(); err != nil { + return ctx.ErrorRes(500, "Cannot check for user MFA", err) + } + + ctx.SetData("hasWebauthn", hasWebauthn) + ctx.SetData("hasTotp", hasTotp) + + return ctx.HTML_("mfa.html") +} diff --git a/internal/web/handlers/auth/oauth.go b/internal/web/handlers/auth/oauth.go new file mode 100644 index 0000000..9bc332b --- /dev/null +++ b/internal/web/handlers/auth/oauth.go @@ -0,0 +1,266 @@ +package auth + +import ( + gocontext "context" + "crypto/md5" + gojson "encoding/json" + "errors" + "fmt" + "github.com/markbates/goth" + "github.com/markbates/goth/gothic" + "github.com/rs/zerolog/log" + "github.com/thomiceli/opengist/internal/auth/oauth" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "golang.org/x/text/cases" + "golang.org/x/text/language" + "gorm.io/gorm" + "io" + "net/http" + "net/url" + "strings" +) + +const ( + GitHubProvider = "github" + GitLabProvider = "gitlab" + GiteaProvider = "gitea" + OpenIDConnect = "openid-connect" +) + +func Oauth(ctx *context.Context) error { + provider := ctx.Param("provider") + + httpProtocol := "http" + if ctx.Request().TLS != nil || ctx.Request().Header.Get("X-Forwarded-Proto") == "https" { + httpProtocol = "https" + } + + forwarded_hdr := ctx.Request().Header.Get("Forwarded") + if forwarded_hdr != "" { + fields := strings.Split(forwarded_hdr, ";") + fwd := make(map[string]string) + for _, v := range fields { + p := strings.Split(v, "=") + fwd[p[0]] = p[1] + } + val, ok := fwd["proto"] + if ok && val == "https" { + httpProtocol = "https" + } + } + + var opengistUrl string + if config.C.ExternalUrl != "" { + opengistUrl = config.C.ExternalUrl + } else { + opengistUrl = httpProtocol + "://" + ctx.Request().Host + } + + providerr := oauth.DefineProvider(provider, opengistUrl) + if providerr == nil { + return ctx.ErrorRes(400, ctx.Tr("error.oauth-unsupported"), nil) + } + + if err := providerr.RegisterProvider(); err != nil { + return ctx.ErrorRes(500, "Cannot create provider", err) + } + + ctxValue := gocontext.WithValue(ctx.Request().Context(), gothic.ProviderParamKey, provider) + ctx.SetRequest(ctx.Request().WithContext(ctxValue)) + + gothic.BeginAuthHandler(ctx.Response(), ctx.Request()) + return nil +} + +func OauthCallback(ctx *context.Context) error { + user, err := gothic.CompleteUserAuth(ctx.Response(), ctx.Request()) + if err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.complete-oauth-login", err.Error()), err) + } + + currUser := ctx.User + if currUser != nil { + // if user is logged in, link account to user and update its avatar URL + updateUserProviderInfo(currUser, user.Provider, user) + + if err = currUser.Update(); err != nil { + return ctx.ErrorRes(500, "Cannot update user "+cases.Title(language.English).String(user.Provider)+" id", err) + } + + ctx.AddFlash(ctx.Tr("flash.auth.account-linked-oauth", cases.Title(language.English).String(user.Provider)), "success") + return ctx.RedirectTo("/settings") + } + + // if user is not in database, create it + userDB, err := db.GetUserByProvider(user.UserID, user.Provider) + if err != nil { + if ctx.GetData("DisableSignup") == true { + return ctx.ErrorRes(403, ctx.Tr("error.signup-disabled"), nil) + } + + if !errors.Is(err, gorm.ErrRecordNotFound) { + return ctx.ErrorRes(500, "Cannot get user", err) + } + + if user.NickName == "" { + user.NickName = strings.Split(user.Email, "@")[0] + } + + userDB = &db.User{ + Username: user.NickName, + Email: user.Email, + MD5Hash: fmt.Sprintf("%x", md5.Sum([]byte(strings.ToLower(strings.TrimSpace(user.Email))))), + } + + // set provider id and avatar URL + updateUserProviderInfo(userDB, user.Provider, user) + + if err = userDB.Create(); err != nil { + if db.IsUniqueConstraintViolation(err) { + ctx.AddFlash(ctx.Tr("flash.auth.username-exists"), "error") + return ctx.RedirectTo("/login") + } + + return ctx.ErrorRes(500, "Cannot create user", err) + } + + if userDB.ID == 1 { + if err = userDB.SetAdmin(); err != nil { + return ctx.ErrorRes(500, "Cannot set user admin", err) + } + } + + var resp *http.Response + switch user.Provider { + case GitHubProvider: + resp, err = http.Get("https://github.com/" + user.NickName + ".keys") + case GitLabProvider: + resp, err = http.Get(urlJoin(config.C.GitlabUrl, user.NickName+".keys")) + case GiteaProvider: + resp, err = http.Get(urlJoin(config.C.GiteaUrl, user.NickName+".keys")) + case OpenIDConnect: + err = errors.New("cannot get keys from OIDC provider") + } + + if err == nil { + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-retrievable"), "error") + log.Error().Err(err).Msg("Could not get user keys") + } + + keys := strings.Split(string(body), "\n") + if len(keys[len(keys)-1]) == 0 { + keys = keys[:len(keys)-1] + } + for _, key := range keys { + sshKey := db.SSHKey{ + Title: "Added from " + user.Provider, + Content: key, + User: *userDB, + } + + if err = sshKey.Create(); err != nil { + ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-created"), "error") + log.Error().Err(err).Msg("Could not create ssh key") + } + } + } + } + + sess := ctx.GetSession() + sess.Values["user"] = userDB.ID + ctx.SaveSession(sess) + ctx.DeleteCsrfCookie() + + return ctx.RedirectTo("/") +} + +func OauthUnlink(ctx *context.Context) error { + provider := ctx.Param("provider") + + currUser := ctx.User + // Map each provider to a function that checks the relevant ID in currUser + providerIDCheckMap := map[string]func() bool{ + GitHubProvider: func() bool { return currUser.GithubID != "" }, + GitLabProvider: func() bool { return currUser.GitlabID != "" }, + GiteaProvider: func() bool { return currUser.GiteaID != "" }, + OpenIDConnect: func() bool { return currUser.OIDCID != "" }, + } + + if checkFunc, exists := providerIDCheckMap[provider]; exists && checkFunc() { + if err := currUser.DeleteProviderID(provider); err != nil { + return ctx.ErrorRes(500, "Cannot unlink account from "+cases.Title(language.English).String(provider), err) + } + + ctx.AddFlash(ctx.Tr("flash.auth.account-unlinked-oauth", cases.Title(language.English).String(provider)), "success") + return ctx.RedirectTo("/settings") + } + + return ctx.RedirectTo("/settings") +} + +func updateUserProviderInfo(userDB *db.User, provider string, user goth.User) { + userDB.AvatarURL = getAvatarUrlFromProvider(provider, user.UserID) + switch provider { + case GitHubProvider: + userDB.GithubID = user.UserID + case GitLabProvider: + userDB.GitlabID = user.UserID + case GiteaProvider: + userDB.GiteaID = user.UserID + case OpenIDConnect: + userDB.OIDCID = user.UserID + userDB.AvatarURL = user.AvatarURL + } +} + +func getAvatarUrlFromProvider(provider string, identifier string) string { + switch provider { + case GitHubProvider: + return "https://avatars.githubusercontent.com/u/" + identifier + "?v=4" + case GitLabProvider: + return urlJoin(config.C.GitlabUrl, "/uploads/-/system/user/avatar/", identifier, "/avatar.png") + "?width=400" + case GiteaProvider: + resp, err := http.Get(urlJoin(config.C.GiteaUrl, "/api/v1/users/", identifier)) + if err != nil { + log.Error().Err(err).Msg("Cannot get user from Gitea") + return "" + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Error().Err(err).Msg("Cannot read Gitea response body") + return "" + } + + var result map[string]interface{} + err = gojson.Unmarshal(body, &result) + if err != nil { + log.Error().Err(err).Msg("Cannot unmarshal Gitea response body") + return "" + } + + field, ok := result["avatar_url"] + if !ok { + log.Error().Msg("Field 'avatar_url' not found in Gitea JSON response") + return "" + } + return field.(string) + } + return "" +} + +func urlJoin(base string, elem ...string) string { + joined, err := url.JoinPath(base, elem...) + if err != nil { + log.Error().Err(err).Msg("Cannot join url") + } + + return joined +} diff --git a/internal/web/handlers/auth/password.go b/internal/web/handlers/auth/password.go new file mode 100644 index 0000000..1ea3143 --- /dev/null +++ b/internal/web/handlers/auth/password.go @@ -0,0 +1,169 @@ +package auth + +import ( + "errors" + "github.com/rs/zerolog/log" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/i18n" + "github.com/thomiceli/opengist/internal/utils" + "github.com/thomiceli/opengist/internal/web/context" + "gorm.io/gorm" +) + +func Register(ctx *context.Context) error { + disableSignup := ctx.GetData("DisableSignup") + disableForm := ctx.GetData("DisableLoginForm") + + code := ctx.QueryParam("code") + if code != "" { + if invitation, err := db.GetInvitationByCode(code); err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return ctx.ErrorRes(500, "Cannot check for invitation code", err) + } else if invitation != nil && invitation.IsUsable() { + disableSignup = false + } + } + + ctx.SetData("title", ctx.TrH("auth.new-account")) + ctx.SetData("htmlTitle", ctx.TrH("auth.new-account")) + ctx.SetData("disableForm", disableForm) + ctx.SetData("disableSignup", disableSignup) + ctx.SetData("isLoginPage", false) + return ctx.HTML_("auth_form.html") +} + +func ProcessRegister(ctx *context.Context) error { + disableSignup := ctx.GetData("DisableSignup") + + code := ctx.QueryParam("code") + invitation, err := db.GetInvitationByCode(code) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return ctx.ErrorRes(500, "Cannot check for invitation code", err) + } else if invitation.ID != 0 && invitation.IsUsable() { + disableSignup = false + } + + if disableSignup == true { + return ctx.ErrorRes(403, ctx.Tr("error.signup-disabled"), nil) + } + + if ctx.GetData("DisableLoginForm") == true { + return ctx.ErrorRes(403, ctx.Tr("error.signup-disabled-form"), nil) + } + + ctx.SetData("title", ctx.TrH("auth.new-account")) + ctx.SetData("htmlTitle", ctx.TrH("auth.new-account")) + + sess := ctx.GetSession() + + dto := new(db.UserDTO) + if err := ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + + if err := ctx.Validate(dto); err != nil { + ctx.AddFlash(utils.ValidationMessages(&err, ctx.GetData("locale").(*i18n.Locale)), "error") + return ctx.HTML_("auth_form.html") + } + + if exists, err := db.UserExists(dto.Username); err != nil || exists { + ctx.AddFlash(ctx.Tr("flash.auth.username-exists"), "error") + return ctx.HTML_("auth_form.html") + } + + user := dto.ToUser() + + password, err := utils.Argon2id.Hash(user.Password) + if err != nil { + return ctx.ErrorRes(500, "Cannot hash password", err) + } + user.Password = password + + if err = user.Create(); err != nil { + return ctx.ErrorRes(500, "Cannot create user", err) + } + + if user.ID == 1 { + if err = user.SetAdmin(); err != nil { + return ctx.ErrorRes(500, "Cannot set user admin", err) + } + } + + if invitation.ID != 0 { + if err := invitation.Use(); err != nil { + return ctx.ErrorRes(500, "Cannot use invitation", err) + } + } + + sess.Values["user"] = user.ID + ctx.SaveSession(sess) + + return ctx.RedirectTo("/") +} + +func Login(ctx *context.Context) error { + ctx.SetData("title", ctx.TrH("auth.login")) + ctx.SetData("htmlTitle", ctx.TrH("auth.login")) + ctx.SetData("disableForm", ctx.GetData("DisableLoginForm")) + ctx.SetData("isLoginPage", true) + return ctx.HTML_("auth_form.html") +} + +func ProcessLogin(ctx *context.Context) error { + if ctx.GetData("DisableLoginForm") == true { + return ctx.ErrorRes(403, ctx.Tr("error.login-disabled-form"), nil) + } + + var err error + sess := ctx.GetSession() + + dto := &db.UserDTO{} + if err = ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + password := dto.Password + + var user *db.User + + if user, err = db.GetUserByUsername(dto.Username); err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return ctx.ErrorRes(500, "Cannot get user", err) + } + log.Warn().Msg("Invalid HTTP authentication attempt from " + ctx.RealIP()) + ctx.AddFlash(ctx.Tr("flash.auth.invalid-credentials"), "error") + return ctx.RedirectTo("/login") + } + + if ok, err := utils.Argon2id.Verify(password, user.Password); !ok { + if err != nil { + return ctx.ErrorRes(500, "Cannot check for password", err) + } + log.Warn().Msg("Invalid HTTP authentication attempt from " + ctx.RealIP()) + ctx.AddFlash(ctx.Tr("flash.auth.invalid-credentials"), "error") + return ctx.RedirectTo("/login") + } + + // handle MFA + var hasWebauthn, hasTotp bool + if hasWebauthn, hasTotp, err = user.HasMFA(); err != nil { + return ctx.ErrorRes(500, "Cannot check for user MFA", err) + } + if hasWebauthn || hasTotp { + sess.Values["mfaID"] = user.ID + sess.Options.MaxAge = 5 * 60 // 5 minutes + ctx.SaveSession(sess) + return ctx.RedirectTo("/mfa") + } + + sess.Values["user"] = user.ID + sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year + ctx.SaveSession(sess) + ctx.DeleteCsrfCookie() + + return ctx.RedirectTo("/") +} + +func Logout(ctx *context.Context) error { + ctx.DeleteSession() + ctx.DeleteCsrfCookie() + return ctx.RedirectTo("/all") +} diff --git a/internal/web/handlers/auth/totp.go b/internal/web/handlers/auth/totp.go new file mode 100644 index 0000000..460103a --- /dev/null +++ b/internal/web/handlers/auth/totp.go @@ -0,0 +1,177 @@ +package auth + +import ( + "github.com/thomiceli/opengist/internal/auth/totp" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "net/url" +) + +func BeginTotp(ctx *context.Context) error { + user := ctx.User + + if _, hasTotp, err := user.HasMFA(); err != nil { + return ctx.ErrorRes(500, "Cannot check for user MFA", err) + } else if hasTotp { + ctx.AddFlash(ctx.Tr("auth.totp.already-enabled"), "error") + return ctx.RedirectTo("/settings") + } + + ogUrl, err := url.Parse(ctx.GetData("baseHttpUrl").(string)) + if err != nil { + return ctx.ErrorRes(500, "Cannot parse base URL", err) + } + + sess := ctx.GetSession() + generatedSecret, _ := sess.Values["generatedSecret"].([]byte) + + totpSecret, qrcode, err, generatedSecret := totp.GenerateQRCode(ctx.User.Username, ogUrl.Hostname(), generatedSecret) + if err != nil { + return ctx.ErrorRes(500, "Cannot generate TOTP QR code", err) + } + sess.Values["totpSecret"] = totpSecret + sess.Values["generatedSecret"] = generatedSecret + ctx.SaveSession(sess) + + ctx.SetData("totpSecret", totpSecret) + ctx.SetData("totpQrcode", qrcode) + + return ctx.HTML_("totp.html") + +} + +func FinishTotp(ctx *context.Context) error { + user := ctx.User + + if _, hasTotp, err := user.HasMFA(); err != nil { + return ctx.ErrorRes(500, "Cannot check for user MFA", err) + } else if hasTotp { + ctx.AddFlash(ctx.Tr("auth.totp.already-enabled"), "error") + return ctx.RedirectTo("/settings") + } + + dto := &db.TOTPDTO{} + if err := ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + + if err := ctx.Validate(dto); err != nil { + ctx.AddFlash("Invalid secret", "error") + return ctx.RedirectTo("/settings/totp/generate") + } + + sess := ctx.GetSession() + secret, ok := sess.Values["totpSecret"].(string) + if !ok { + return ctx.ErrorRes(500, "Cannot get TOTP secret from session", nil) + } + + if !totp.Validate(dto.Code, secret) { + ctx.AddFlash(ctx.Tr("auth.totp.invalid-code"), "error") + + return ctx.RedirectTo("/settings/totp/generate") + } + + userTotp := &db.TOTP{ + UserID: ctx.User.ID, + } + if err := userTotp.StoreSecret(secret); err != nil { + return ctx.ErrorRes(500, "Cannot store TOTP secret", err) + } + + if err := userTotp.Create(); err != nil { + return ctx.ErrorRes(500, "Cannot create TOTP", err) + } + + ctx.AddFlash("TOTP successfully enabled", "success") + codes, err := userTotp.GenerateRecoveryCodes() + if err != nil { + return ctx.ErrorRes(500, "Cannot generate recovery codes", err) + } + + delete(sess.Values, "totpSecret") + delete(sess.Values, "generatedSecret") + ctx.SaveSession(sess) + + ctx.SetData("recoveryCodes", codes) + return ctx.HTML_("totp.html") +} + +func AssertTotp(ctx *context.Context) error { + var err error + dto := &db.TOTPDTO{} + if err := ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + + if err := ctx.Validate(dto); err != nil { + ctx.AddFlash(ctx.Tr("auth.totp.invalid-code"), "error") + return ctx.RedirectTo("/mfa") + } + + sess := ctx.GetSession() + userId := sess.Values["mfaID"].(uint) + var userTotp *db.TOTP + if userTotp, err = db.GetTOTPByUserID(userId); err != nil { + return ctx.ErrorRes(500, "Cannot get TOTP by UID", err) + } + + redirectUrl := "/" + + var validCode, validRecoveryCode bool + if validCode, err = userTotp.ValidateCode(dto.Code); err != nil { + return ctx.ErrorRes(500, "Cannot validate TOTP code", err) + } + if !validCode { + validRecoveryCode, err = userTotp.ValidateRecoveryCode(dto.Code) + if err != nil { + return ctx.ErrorRes(500, "Cannot validate TOTP code", err) + } + + if !validRecoveryCode { + ctx.AddFlash(ctx.Tr("auth.totp.invalid-code"), "error") + return ctx.RedirectTo("/mfa") + } + + ctx.AddFlash(ctx.Tr("auth.totp.code-used", dto.Code), "warning") + redirectUrl = "/settings" + } + + sess.Values["user"] = userId + sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year + delete(sess.Values, "mfaID") + ctx.SaveSession(sess) + + return ctx.RedirectTo(redirectUrl) +} + +func DisableTotp(ctx *context.Context) error { + user := ctx.User + userTotp, err := db.GetTOTPByUserID(user.ID) + if err != nil { + return ctx.ErrorRes(500, "Cannot get TOTP by UID", err) + } + + if err = userTotp.Delete(); err != nil { + return ctx.ErrorRes(500, "Cannot delete TOTP", err) + } + + ctx.AddFlash(ctx.Tr("auth.totp.disabled"), "success") + return ctx.RedirectTo("/settings") +} + +func RegenerateTotpRecoveryCodes(ctx *context.Context) error { + user := ctx.User + userTotp, err := db.GetTOTPByUserID(user.ID) + if err != nil { + return ctx.ErrorRes(500, "Cannot get TOTP by UID", err) + } + + codes, err := userTotp.GenerateRecoveryCodes() + if err != nil { + return ctx.ErrorRes(500, "Cannot generate recovery codes", err) + } + + ctx.SetData("recoveryCodes", codes) + return ctx.HTML_("totp.html") +} diff --git a/internal/web/handlers/auth/webauthn.go b/internal/web/handlers/auth/webauthn.go new file mode 100644 index 0000000..3e8dcd4 --- /dev/null +++ b/internal/web/handlers/auth/webauthn.go @@ -0,0 +1,151 @@ +package auth + +import ( + "bytes" + gojson "encoding/json" + "github.com/thomiceli/opengist/internal/auth/webauthn" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "io" +) + +func BeginWebAuthnBinding(ctx *context.Context) error { + credsCreation, jsonWaSession, err := webauthn.BeginBinding(ctx.User) + if err != nil { + return ctx.ErrorRes(500, "Cannot begin WebAuthn registration", err) + } + + sess := ctx.GetSession() + sess.Values["webauthn_registration_session"] = jsonWaSession + sess.Options.MaxAge = 5 * 60 // 5 minutes + ctx.SaveSession(sess) + + return ctx.JSON(200, credsCreation) +} + +func FinishWebAuthnBinding(ctx *context.Context) error { + sess := ctx.GetSession() + jsonWaSession, ok := sess.Values["webauthn_registration_session"].([]byte) + if !ok { + return ctx.ErrorRes(401, "Cannot get WebAuthn registration session", nil) + } + + user := ctx.User + + // extract passkey name from request + body, err := io.ReadAll(ctx.Request().Body) + if err != nil { + return ctx.ErrorRes(400, "Failed to read request body", err) + } + ctx.Request().Body.Close() + ctx.Request().Body = io.NopCloser(bytes.NewBuffer(body)) + + dto := new(db.CrendentialDTO) + _ = gojson.Unmarshal(body, &dto) + + if err = ctx.Validate(dto); err != nil { + return ctx.ErrorRes(400, "Invalid request", err) + } + passkeyName := dto.PasskeyName + if passkeyName == "" { + passkeyName = "WebAuthn" + } + + waCredential, err := webauthn.FinishBinding(user, jsonWaSession, ctx.Request()) + if err != nil { + return ctx.ErrorRes(403, "Failed binding attempt for passkey", err) + } + + if _, err = db.CreateFromCrendential(user.ID, passkeyName, waCredential); err != nil { + return ctx.ErrorRes(500, "Cannot create WebAuthn credential on database", err) + } + + delete(sess.Values, "webauthn_registration_session") + ctx.SaveSession(sess) + + ctx.AddFlash(ctx.Tr("flash.auth.passkey-registred", passkeyName), "success") + return ctx.JSON_([]string{"OK"}) +} + +func BeginWebAuthnLogin(ctx *context.Context) error { + credsCreation, jsonWaSession, err := webauthn.BeginDiscoverableLogin() + if err != nil { + return ctx.ErrorRes(401, "Cannot begin WebAuthn login", err) + } + + sess := ctx.GetSession() + sess.Values["webauthn_login_session"] = jsonWaSession + sess.Options.MaxAge = 5 * 60 // 5 minutes + ctx.SaveSession(sess) + + return ctx.JSON_(credsCreation) +} + +func FinishWebAuthnLogin(ctx *context.Context) error { + sess := ctx.GetSession() + sessionData, ok := sess.Values["webauthn_login_session"].([]byte) + if !ok { + return ctx.ErrorRes(401, "Cannot get WebAuthn login session", nil) + } + + userID, err := webauthn.FinishDiscoverableLogin(sessionData, ctx.Request()) + if err != nil { + return ctx.ErrorRes(403, "Failed authentication attempt for passkey", err) + } + + sess.Values["user"] = userID + sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year + + delete(sess.Values, "webauthn_login_session") + ctx.SaveSession(sess) + + return ctx.JSON_([]string{"OK"}) +} + +func BeginWebAuthnAssertion(ctx *context.Context) error { + sess := ctx.GetSession() + + ogUser, err := db.GetUserById(sess.Values["mfaID"].(uint)) + if err != nil { + return ctx.ErrorRes(500, "Cannot get user", err) + } + + credsCreation, jsonWaSession, err := webauthn.BeginLogin(ogUser) + if err != nil { + return ctx.ErrorRes(401, "Cannot begin WebAuthn login", err) + } + + sess.Values["webauthn_assertion_session"] = jsonWaSession + sess.Options.MaxAge = 5 * 60 // 5 minutes + ctx.SaveSession(sess) + + return ctx.JSON_(credsCreation) +} + +func FinishWebAuthnAssertion(ctx *context.Context) error { + sess := ctx.GetSession() + sessionData, ok := sess.Values["webauthn_assertion_session"].([]byte) + if !ok { + return ctx.ErrorRes(401, "Cannot get WebAuthn assertion session", nil) + } + + userId := sess.Values["mfaID"].(uint) + + ogUser, err := db.GetUserById(userId) + if err != nil { + return ctx.ErrorRes(500, "Cannot get user", err) + } + + if err = webauthn.FinishLogin(ogUser, sessionData, ctx.Request()); err != nil { + return ctx.ErrorRes(403, "Failed authentication attempt for passkey", err) + } + + sess.Values["user"] = userId + sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year + + delete(sess.Values, "webauthn_assertion_session") + delete(sess.Values, "mfaID") + ctx.SaveSession(sess) + + return ctx.JSON_([]string{"OK"}) +} diff --git a/internal/web/server/router.go b/internal/web/server/router.go index 48759e6..15cd739 100644 --- a/internal/web/server/router.go +++ b/internal/web/server/router.go @@ -7,6 +7,7 @@ import ( "github.com/thomiceli/opengist/internal/web/context" "github.com/thomiceli/opengist/internal/web/handlers" "github.com/thomiceli/opengist/internal/web/handlers/admin" + "github.com/thomiceli/opengist/internal/web/handlers/auth" "github.com/thomiceli/opengist/public" "net/http" "os" @@ -27,22 +28,22 @@ func (s *Server) registerRoutes() { r.GET("/healthcheck", handlers.Healthcheck) r.GET("/metrics", handlers.Metrics) - r.GET("/register", handlers.Register) - r.POST("/register", handlers.ProcessRegister) - r.GET("/login", handlers.Login) - r.POST("/login", handlers.ProcessLogin) - r.GET("/logout", handlers.Logout) - r.GET("/oauth/:provider", handlers.Oauth) - r.GET("/oauth/:provider/callback", handlers.OauthCallback) - r.GET("/oauth/:provider/unlink", handlers.OauthUnlink, logged) - r.POST("/webauthn/bind", handlers.BeginWebAuthnBinding, logged) - r.POST("/webauthn/bind/finish", handlers.FinishWebAuthnBinding, logged) - r.POST("/webauthn/login", handlers.BeginWebAuthnLogin) - r.POST("/webauthn/login/finish", handlers.FinishWebAuthnLogin) - r.POST("/webauthn/assertion", handlers.BeginWebAuthnAssertion, inMFASession) - r.POST("/webauthn/assertion/finish", handlers.FinishWebAuthnAssertion, inMFASession) - r.GET("/mfa", handlers.Mfa, inMFASession) - r.POST("/mfa/totp/assertion", handlers.AssertTotp, inMFASession) + r.GET("/register", auth.Register) + r.POST("/register", auth.ProcessRegister) + r.GET("/login", auth.Login) + r.POST("/login", auth.ProcessLogin) + r.GET("/logout", auth.Logout) + r.GET("/oauth/:provider", auth.Oauth) + r.GET("/oauth/:provider/callback", auth.OauthCallback) + r.GET("/oauth/:provider/unlink", auth.OauthUnlink, logged) + r.POST("/webauthn/bind", auth.BeginWebAuthnBinding, logged) + r.POST("/webauthn/bind/finish", auth.FinishWebAuthnBinding, logged) + r.POST("/webauthn/login", auth.BeginWebAuthnLogin) + r.POST("/webauthn/login/finish", auth.FinishWebAuthnLogin) + r.POST("/webauthn/assertion", auth.BeginWebAuthnAssertion, inMFASession) + r.POST("/webauthn/assertion/finish", auth.FinishWebAuthnAssertion, inMFASession) + r.GET("/mfa", auth.Mfa, inMFASession) + r.POST("/mfa/totp/assertion", auth.AssertTotp, inMFASession) sA := r.SubGroup("/settings") { @@ -55,10 +56,10 @@ func (s *Server) registerRoutes() { sA.DELETE("/passkeys/:id", handlers.PasskeyDelete) sA.PUT("/password", handlers.PasswordProcess) sA.PUT("/username", handlers.UsernameProcess) - sA.GET("/totp/generate", handlers.BeginTotp) - sA.POST("/totp/generate", handlers.FinishTotp) - sA.DELETE("/totp", handlers.DisableTotp) - sA.POST("/totp/regenerate", handlers.RegenerateTotpRecoveryCodes) + sA.GET("/totp/generate", auth.BeginTotp) + sA.POST("/totp/generate", auth.FinishTotp) + sA.DELETE("/totp", auth.DisableTotp) + sA.POST("/totp/regenerate", auth.RegenerateTotpRecoveryCodes) } sB := r.SubGroup("/admin-panel")