opengist/internal/web/handlers/auth/oauth.go

267 lines
7.3 KiB
Go
Raw Normal View History

2025-01-11 19:17:01 +00:00
package auth
import (
gocontext "context"
"crypto/md5"
gojson "encoding/json"
"errors"
"fmt"
"github.com/markbates/goth"
"github.com/markbates/goth/gothic"
"github.com/rs/zerolog/log"
"github.com/thomiceli/opengist/internal/auth/oauth"
"github.com/thomiceli/opengist/internal/config"
"github.com/thomiceli/opengist/internal/db"
"github.com/thomiceli/opengist/internal/web/context"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"gorm.io/gorm"
"io"
"net/http"
"net/url"
"strings"
)
// Identifiers of the OAuth providers handled in this file. They are compared
// against the "provider" route parameter and against the Provider field of
// the goth user returned after authentication.
const (
// GitHubProvider identifies GitHub.
GitHubProvider = "github"
// GitLabProvider identifies GitLab (base URL taken from config.C.GitlabUrl).
GitLabProvider = "gitlab"
// GiteaProvider identifies Gitea (base URL taken from config.C.GiteaUrl).
GiteaProvider = "gitea"
// OpenIDConnect identifies a generic OpenID Connect provider.
OpenIDConnect = "openid-connect"
)
// Oauth starts the OAuth login flow for the provider named by the "provider"
// route parameter.
//
// The base URL handed to the provider definition is config.C.ExternalUrl when
// it is set. Otherwise it is built from the request Host and a scheme that is
// "https" when the connection uses TLS, or when the X-Forwarded-Proto or
// Forwarded header reports https, and "http" in every other case.
//
// It answers with a 400 error when the provider is not supported and with a
// 500 error when the provider cannot be registered. On success the request is
// handed over to gothic.BeginAuthHandler and nil is returned.
func Oauth(ctx *context.Context) error {
	provider := ctx.Param("provider")

	httpProtocol := "http"
	if ctx.Request().TLS != nil || ctx.Request().Header.Get("X-Forwarded-Proto") == "https" {
		httpProtocol = "https"
	}
	if forwardedProto(ctx.Request().Header.Get("Forwarded")) == "https" {
		httpProtocol = "https"
	}

	opengistURL := config.C.ExternalUrl
	if opengistURL == "" {
		opengistURL = httpProtocol + "://" + ctx.Request().Host
	}

	oauthProvider := oauth.DefineProvider(provider, opengistURL)
	if oauthProvider == nil {
		return ctx.ErrorRes(400, ctx.Tr("error.oauth-unsupported"), nil)
	}
	if err := oauthProvider.RegisterProvider(); err != nil {
		return ctx.ErrorRes(500, "Cannot create provider", err)
	}

	// gothic reads the provider name from the request context.
	ctxValue := gocontext.WithValue(ctx.Request().Context(), gothic.ProviderParamKey, provider)
	ctx.SetRequest(ctx.Request().WithContext(ctxValue))

	gothic.BeginAuthHandler(ctx.Response(), ctx.Request())
	return nil
}

// forwardedProto extracts the "proto" parameter from the value of a Forwarded
// header (RFC 7239) and returns it lower-cased, or "" when it is absent.
//
// The header is client-controlled, so the parsing is tolerant: fields without
// an "=" are skipped instead of being indexed blindly, surrounding whitespace
// and double quotes are removed, and the parameter name is matched
// case-insensitively. When "proto" appears several times the last one wins.
func forwardedProto(header string) string {
	proto := ""
	for _, field := range strings.Split(header, ";") {
		key, value, found := strings.Cut(field, "=")
		if !found || !strings.EqualFold(strings.TrimSpace(key), "proto") {
			continue
		}
		proto = strings.ToLower(strings.Trim(strings.TrimSpace(value), `"`))
	}
	return proto
}
// OauthCallback completes the OAuth flow once the provider redirects back.
//
// Depending on the state of the request it does one of the following:
//   - a user is already logged in: the provider account is linked to that
//     user (provider ID and avatar URL are updated) and the user is sent to
//     "/settings";
//   - the provider account is already known: that user is logged in;
//   - the provider account is unknown: a new user is created (unless signup
//     is disabled), the user with ID 1 is made admin, the public SSH keys
//     are imported from the provider on a best-effort basis, and the user is
//     logged in.
//
// Failures are reported through ctx.ErrorRes: 400 when the OAuth exchange
// fails, 403 when signup is disabled, 500 for database errors. A username
// collision flashes an error and redirects to "/login".
func OauthCallback(ctx *context.Context) error {
	user, err := gothic.CompleteUserAuth(ctx.Response(), ctx.Request())
	if err != nil {
		return ctx.ErrorRes(400, ctx.Tr("error.complete-oauth-login", err.Error()), err)
	}

	if currUser := ctx.User; currUser != nil {
		// A user is logged in: link the provider account and refresh the avatar URL.
		providerTitle := cases.Title(language.English).String(user.Provider)
		updateUserProviderInfo(currUser, user.Provider, user)
		if err = currUser.Update(); err != nil {
			return ctx.ErrorRes(500, "Cannot update user "+providerTitle+" id", err)
		}

		ctx.AddFlash(ctx.Tr("flash.auth.account-linked-oauth", providerTitle), "success")
		return ctx.RedirectTo("/settings")
	}

	userDB, err := db.GetUserByProvider(user.UserID, user.Provider)
	if err != nil {
		// Only a missing record means "unknown account". Any other error is a
		// lookup failure and must not be reported as "signup disabled".
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return ctx.ErrorRes(500, "Cannot get user", err)
		}
		if ctx.GetData("DisableSignup") == true {
			return ctx.ErrorRes(403, ctx.Tr("error.signup-disabled"), nil)
		}

		// Providers that send no nickname fall back to the local part of the email.
		if user.NickName == "" {
			user.NickName = strings.Split(user.Email, "@")[0]
		}

		userDB = &db.User{
			Username: user.NickName,
			Email:    user.Email,
			MD5Hash:  fmt.Sprintf("%x", md5.Sum([]byte(strings.ToLower(strings.TrimSpace(user.Email))))),
		}

		// Set the provider ID and the avatar URL.
		updateUserProviderInfo(userDB, user.Provider, user)

		if err = userDB.Create(); err != nil {
			if db.IsUniqueConstraintViolation(err) {
				ctx.AddFlash(ctx.Tr("flash.auth.username-exists"), "error")
				return ctx.RedirectTo("/login")
			}
			return ctx.ErrorRes(500, "Cannot create user", err)
		}

		// The very first account becomes the administrator.
		if userDB.ID == 1 {
			if err = userDB.SetAdmin(); err != nil {
				return ctx.ErrorRes(500, "Cannot set user admin", err)
			}
		}

		importOauthSSHKeys(ctx, userDB, user)
	}

	sess := ctx.GetSession()
	sess.Values["user"] = userDB.ID
	ctx.SaveSession(sess)
	ctx.DeleteCsrfCookie()

	return ctx.RedirectTo("/")
}

// importOauthSSHKeys fetches the public SSH keys published by the provider for
// user.NickName and stores each of them for userDB.
//
// The import is best effort: nothing is returned, failures are logged and
// surfaced to the user as flash messages. Providers without a known keys
// endpoint (OpenID Connect or anything unrecognised) are skipped silently.
func importOauthSSHKeys(ctx *context.Context, userDB *db.User, user goth.User) {
	var keysURL string
	switch user.Provider {
	case GitHubProvider:
		keysURL = "https://github.com/" + user.NickName + ".keys"
	case GitLabProvider:
		keysURL = urlJoin(config.C.GitlabUrl, user.NickName+".keys")
	case GiteaProvider:
		keysURL = urlJoin(config.C.GiteaUrl, user.NickName+".keys")
	default:
		// No keys endpoint to query, so there is no response to read.
		return
	}

	resp, err := http.Get(keysURL)
	if err != nil {
		ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-retrievable"), "error")
		log.Error().Err(err).Msg("Could not get user keys")
		return
	}
	defer resp.Body.Close()

	// Anything but 200 carries an error page, not a list of keys.
	if resp.StatusCode != http.StatusOK {
		ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-retrievable"), "error")
		log.Error().Int("status", resp.StatusCode).Msg("Could not get user keys")
		return
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		// Do not store keys parsed from a truncated body.
		ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-retrievable"), "error")
		log.Error().Err(err).Msg("Could not get user keys")
		return
	}

	for _, line := range strings.Split(string(body), "\n") {
		key := strings.TrimSpace(line)
		if key == "" {
			continue
		}

		sshKey := db.SSHKey{
			Title:   "Added from " + user.Provider,
			Content: key,
			User:    *userDB,
		}
		if err := sshKey.Create(); err != nil {
			ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-created"), "error")
			log.Error().Err(err).Msg("Could not create ssh key")
		}
	}
}
// OauthUnlink removes the link between the logged-in user and the provider
// named by the "provider" route parameter.
//
// Nothing is changed when the provider is unknown or the user has no ID
// stored for it. Every outcome except a failed removal (reported as a 500
// error) ends with a redirect to "/settings".
func OauthUnlink(ctx *context.Context) error {
	provider := ctx.Param("provider")
	currUser := ctx.User

	// Only the field belonging to the requested provider is inspected.
	var linked bool
	switch provider {
	case GitHubProvider:
		linked = currUser.GithubID != ""
	case GitLabProvider:
		linked = currUser.GitlabID != ""
	case GiteaProvider:
		linked = currUser.GiteaID != ""
	case OpenIDConnect:
		linked = currUser.OIDCID != ""
	}

	if !linked {
		return ctx.RedirectTo("/settings")
	}

	providerTitle := cases.Title(language.English).String(provider)
	if err := currUser.DeleteProviderID(provider); err != nil {
		return ctx.ErrorRes(500, "Cannot unlink account from "+providerTitle, err)
	}

	ctx.AddFlash(ctx.Tr("flash.auth.account-unlinked-oauth", providerTitle), "success")
	return ctx.RedirectTo("/settings")
}
// updateUserProviderInfo copies the provider-specific data of the
// authenticated goth user onto userDB: the provider user ID goes into the
// field matching provider, and the avatar URL is refreshed.
//
// The avatar URL comes from getAvatarUrlFromProvider, except for OpenID
// Connect where the URL reported by the provider itself is used. For an
// unrecognised provider no ID is stored and the avatar URL ends up empty.
// Nothing is persisted here; saving userDB is left to the caller.
func updateUserProviderInfo(userDB *db.User, provider string, user goth.User) {
	providerID := user.UserID
	avatarURL := getAvatarUrlFromProvider(provider, providerID)

	switch provider {
	case GitHubProvider:
		userDB.GithubID = providerID
	case GitLabProvider:
		userDB.GitlabID = providerID
	case GiteaProvider:
		userDB.GiteaID = providerID
	case OpenIDConnect:
		userDB.OIDCID = providerID
		avatarURL = user.AvatarURL
	}

	userDB.AvatarURL = avatarURL
}
// getAvatarUrlFromProvider returns the avatar URL of the account identified
// by identifier on the given provider, or "" when it cannot be determined.
//
// For GitHub and GitLab the URL is built from the identifier without any
// network access. For Gitea the user is looked up through the
// "/api/v1/users/" endpoint of config.C.GiteaUrl and the "avatar_url" field
// of the JSON answer is returned. Every failure of that lookup is logged and
// yields "". Any other provider yields "".
func getAvatarUrlFromProvider(provider string, identifier string) string {
	switch provider {
	case GitHubProvider:
		return "https://avatars.githubusercontent.com/u/" + identifier + "?v=4"
	case GitLabProvider:
		return urlJoin(config.C.GitlabUrl, "/uploads/-/system/user/avatar/", identifier, "/avatar.png") + "?width=400"
	case GiteaProvider:
		resp, err := http.Get(urlJoin(config.C.GiteaUrl, "/api/v1/users/", identifier))
		if err != nil {
			log.Error().Err(err).Msg("Cannot get user from Gitea")
			return ""
		}
		defer resp.Body.Close()

		// An error answer has no avatar; report the status instead of a
		// misleading "field not found".
		if resp.StatusCode != http.StatusOK {
			log.Error().Int("status", resp.StatusCode).Msg("Cannot get user from Gitea")
			return ""
		}

		body, err := io.ReadAll(resp.Body)
		if err != nil {
			log.Error().Err(err).Msg("Cannot read Gitea response body")
			return ""
		}

		var result map[string]interface{}
		if err = gojson.Unmarshal(body, &result); err != nil {
			log.Error().Err(err).Msg("Cannot unmarshal Gitea response body")
			return ""
		}

		field, ok := result["avatar_url"]
		if !ok {
			log.Error().Msg("Field 'avatar_url' not found in Gitea JSON response")
			return ""
		}

		// The field may be null or of another JSON type; a bare type
		// assertion would panic on it.
		avatarURL, ok := field.(string)
		if !ok {
			log.Error().Msg("Field 'avatar_url' in Gitea JSON response is not a string")
			return ""
		}
		return avatarURL
	}
	return ""
}
// urlJoin appends the path elements elem to base using url.JoinPath.
// When joining fails the error is logged and whatever url.JoinPath returned
// alongside the error is handed back to the caller.
func urlJoin(base string, elem ...string) string {
	result, joinErr := url.JoinPath(base, elem...)
	if joinErr == nil {
		return result
	}

	log.Error().Err(joinErr).Msg("Cannot join url")
	return result
}