diff --git a/internal/web/server/util.go b/internal/web/server/handler.go similarity index 83% rename from internal/web/server/util.go rename to internal/web/server/handler.go index 6e2ce1f..60321aa 100644 --- a/internal/web/server/util.go +++ b/internal/web/server/handler.go @@ -32,9 +32,7 @@ func (h Handler) toEchoHandler() echo.HandlerFunc { } } -// Chain applies middleware to a handler without conversion to echo types -func Chain(h Handler, middleware ...Middleware) Handler { - // Apply middleware in reverse order +func chain(h Handler, middleware ...Middleware) Handler { for i := len(middleware) - 1; i >= 0; i-- { h = middleware[i](h) } diff --git a/internal/web/server/middlewares.go b/internal/web/server/middlewares.go index 3999069..8620d53 100644 --- a/internal/web/server/middlewares.go +++ b/internal/web/server/middlewares.go @@ -1,14 +1,17 @@ package server import ( + "errors" "fmt" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "github.com/rs/zerolog/log" + "github.com/thomiceli/opengist/internal/auth" "github.com/thomiceli/opengist/internal/config" "github.com/thomiceli/opengist/internal/db" "github.com/thomiceli/opengist/internal/i18n" "github.com/thomiceli/opengist/internal/web/context" + "github.com/thomiceli/opengist/internal/web/handler" "golang.org/x/text/cases" "golang.org/x/text/language" "html/template" @@ -27,7 +30,7 @@ func (s *Server) useCustomContext() { }) } -func (s *Server) RegisterMiddlewares() { +func (s *Server) registerMiddlewares() { s.echo.Use(Middleware(dataInit).ToEcho()) s.echo.Use(Middleware(locale).ToEcho()) @@ -61,6 +64,28 @@ func (s *Server) RegisterMiddlewares() { } } +func (s *Server) errorHandler(err error, ctx echo.Context) { + var httpErr *echo.HTTPError + if errors.As(err, &httpErr) { + acceptJson := strings.Contains(ctx.Request().Header.Get("Accept"), "application/json") + data := ctx.Request().Context().Value("data").(echo.Map) + data["error"] = err + if acceptJson { + if err := ctx.JSON(httpErr.Code, httpErr); err != nil { + log.Fatal().Err(err).Send() + } + return + } + + if err := ctx.Render(httpErr.Code, "error", data); err != nil { + log.Fatal().Err(err).Send() + } + return + } + + log.Fatal().Err(err).Send() +} + func dataInit(next Handler) Handler { return func(ctx *context.OGContext) error { ctx.SetData("loadStartTime", time.Now()) @@ -96,6 +121,77 @@ func dataInit(next Handler) Handler { } } +func writePermission(next Handler) Handler { + return func(ctx *context.OGContext) error { + gist := ctx.GetData("gist") + user := ctx.User + if !gist.(*db.Gist).CanWrite(user) { + return ctx.RedirectTo("/" + gist.(*db.Gist).User.Username + "/" + gist.(*db.Gist).Identifier()) + } + return next(ctx) + } +} + +func adminPermission(next Handler) Handler { + return func(ctx *context.OGContext) error { + user := ctx.User + if user == nil || !user.IsAdmin { + return ctx.NotFound("User not found") + } + return next(ctx) + } +} + +func logged(next Handler) Handler { + return func(ctx *context.OGContext) error { + user := ctx.User + if user != nil { + return next(ctx) + } + return ctx.RedirectTo("/all") + } +} + +func inMFASession(next Handler) Handler { + return func(ctx *context.OGContext) error { + sess := ctx.GetSession() + _, ok := sess.Values["mfaID"].(uint) + if !ok { + return ctx.ErrorRes(400, ctx.Tr("error.not-in-mfa-session"), nil) + } + return next(ctx) + } +} + +func makeCheckRequireLogin(isSingleGistAccess bool) Middleware { + return func(next Handler) Handler { + return func(ctx *context.OGContext) error { + if user := ctx.User; user != nil { + return next(ctx) + } + + allow, err := auth.ShouldAllowUnauthenticatedGistAccess(handler.ContextAuthInfo{Context: ctx}, isSingleGistAccess) + if err != nil { + log.Fatal().Err(err).Msg("Failed to check if unauthenticated access is allowed") + } + + if !allow { + ctx.AddFlash(ctx.Tr("flash.auth.must-be-logged-in"), "error") + return ctx.RedirectTo("/login") + } + return next(ctx) + } + } +} + +func checkRequireLogin(next Handler) Handler { + return makeCheckRequireLogin(false)(next) +} + +func noRouteFound(ctx *context.OGContext) error { + return ctx.NotFound("Page not found") +} + func locale(next Handler) Handler { return func(ctx *context.OGContext) error { // Check URL arguments diff --git a/internal/web/server/renderer.go b/internal/web/server/renderer.go index 0cfda49..51a2d0c 100644 --- a/internal/web/server/renderer.go +++ b/internal/web/server/renderer.go @@ -4,6 +4,7 @@ import ( gojson "encoding/json" "errors" "fmt" + "github.com/labstack/echo/v4" "github.com/rs/zerolog/log" "github.com/thomiceli/opengist/internal/config" "github.com/thomiceli/opengist/internal/db" @@ -25,6 +26,14 @@ import ( "time" ) +type Template struct { + templates *template.Template +} + +func (t *Template) Render(w io.Writer, name string, data interface{}, _ echo.Context) error { + return t.templates.ExecuteTemplate(w, name, data) +} + var re = regexp.MustCompile("[^a-z0-9]+") func (s *Server) setFuncMap() { @@ -189,7 +198,7 @@ func (s *Server) setFuncMap() { } } -func parseManifestEntries() { +func (s *Server) parseManifestEntries() { file, err := public.Files.Open("manifest.json") if err != nil { log.Fatal().Err(err).Msg("Failed to open manifest.json") diff --git a/internal/web/server/router.go b/internal/web/server/router.go index 165da14..f963c07 100644 --- a/internal/web/server/router.go +++ b/internal/web/server/router.go @@ -15,12 +15,10 @@ import ( "time" ) -func (s *Server) setupRoutes() { +func (s *Server) registerRoutes() { r := NewRouter(s.echo.Group("")) - // Web based routes { - r.GET("/", handler.Create, logged) r.POST("/", handler.ProcessCreate, logged) r.POST("/preview", handler.Preview, logged) @@ -45,38 +43,42 @@ func (s *Server) setupRoutes() { r.GET("/mfa", handler.Mfa, inMFASession) r.POST("/mfa/totp/assertion", handler.AssertTotp, inMFASession) - r.GET("/settings", handler.UserSettings, logged) - r.POST("/settings/email", handler.EmailProcess, logged) - r.DELETE("/settings/account", handler.AccountDeleteProcess, logged) - r.POST("/settings/ssh-keys", handler.SshKeysProcess, logged) - r.DELETE("/settings/ssh-keys/:id", handler.SshKeysDelete, logged) - r.DELETE("/settings/passkeys/:id", handler.PasskeyDelete, logged) - r.PUT("/settings/password", handler.PasswordProcess, logged) - r.PUT("/settings/username", handler.UsernameProcess, logged) - r.GET("/settings/totp/generate", handler.BeginTotp, logged) - r.POST("/settings/totp/generate", handler.FinishTotp, logged) - r.DELETE("/settings/totp", handler.DisableTotp, logged) - r.POST("/settings/totp/regenerate", handler.RegenerateTotpRecoveryCodes, logged) - - g2 := r.SubGroup("/admin-panel") + sA := r.SubGroup("/settings") { - g2.Use(adminPermission) - g2.GET("", handler.AdminIndex) - g2.GET("/users", handler.AdminUsers) - g2.POST("/users/:user/delete", handler.AdminUserDelete) - g2.GET("/gists", handler.AdminGists) - g2.POST("/gists/:gist/delete", handler.AdminGistDelete) - g2.GET("/invitations", handler.AdminInvitations) - g2.POST("/invitations", handler.AdminInvitationsCreate) - g2.POST("/invitations/:id/delete", handler.AdminInvitationsDelete) - g2.POST("/sync-fs", handler.AdminSyncReposFromFS) - g2.POST("/sync-db", handler.AdminSyncReposFromDB) - g2.POST("/gc-repos", handler.AdminGcRepos) - g2.POST("/sync-previews", handler.AdminSyncGistPreviews) - g2.POST("/reset-hooks", handler.AdminResetHooks) - g2.POST("/index-gists", handler.AdminIndexGists) - g2.GET("/configuration", handler.AdminConfig) - g2.PUT("/set-config", handler.AdminSetConfig) + sA.Use(logged) + sA.GET("", handler.UserSettings) + sA.POST("/email", handler.EmailProcess) + sA.DELETE("/account", handler.AccountDeleteProcess) + sA.POST("/ssh-keys", handler.SshKeysProcess) + sA.DELETE("/ssh-keys/:id", handler.SshKeysDelete) + sA.DELETE("/passkeys/:id", handler.PasskeyDelete) + sA.PUT("/password", handler.PasswordProcess) + sA.PUT("/username", handler.UsernameProcess) + sA.GET("/totp/generate", handler.BeginTotp) + sA.POST("/totp/generate", handler.FinishTotp) + sA.DELETE("/totp", handler.DisableTotp) + sA.POST("/totp/regenerate", handler.RegenerateTotpRecoveryCodes) + } + + sB := r.SubGroup("/admin-panel") + { + sB.Use(adminPermission) + sB.GET("", handler.AdminIndex) + sB.GET("/users", handler.AdminUsers) + sB.POST("/users/:user/delete", handler.AdminUserDelete) + sB.GET("/gists", handler.AdminGists) + sB.POST("/gists/:gist/delete", handler.AdminGistDelete) + sB.GET("/invitations", handler.AdminInvitations) + sB.POST("/invitations", handler.AdminInvitationsCreate) + sB.POST("/invitations/:id/delete", handler.AdminInvitationsDelete) + sB.POST("/sync-fs", handler.AdminSyncReposFromFS) + sB.POST("/sync-db", handler.AdminSyncReposFromDB) + sB.POST("/gc-repos", handler.AdminGcRepos) + sB.POST("/sync-previews", handler.AdminSyncGistPreviews) + sB.POST("/reset-hooks", handler.AdminResetHooks) + sB.POST("/index-gists", handler.AdminIndexGists) + sB.GET("/configuration", handler.AdminConfig) + sB.PUT("/set-config", handler.AdminSetConfig) } if config.C.HttpGit { @@ -95,24 +97,24 @@ func (s *Server) setupRoutes() { r.GET("/:user/liked", handler.AllGists, checkRequireLogin) r.GET("/:user/forked", handler.AllGists, checkRequireLogin) - g3 := r.SubGroup("/:user/:gistname") + sC := r.SubGroup("/:user/:gistname") { - g3.Use(makeCheckRequireLogin(true), GistInit) - g3.GET("", handler.GistIndex) - g3.GET("/rev/:revision", handler.GistIndex) - g3.GET("/revisions", handler.Revisions) - g3.GET("/archive/:revision", handler.DownloadZip) - g3.POST("/visibility", handler.EditVisibility, logged, writePermission) - g3.POST("/delete", handler.DeleteGist, logged, writePermission) - g3.GET("/raw/:revision/:file", handler.RawFile) - g3.GET("/download/:revision/:file", handler.DownloadFile) - g3.GET("/edit", handler.Edit, logged, writePermission) - g3.POST("/edit", handler.ProcessCreate, logged, writePermission) - g3.POST("/like", handler.Like, logged) - g3.GET("/likes", handler.Likes, checkRequireLogin) - g3.POST("/fork", handler.Fork, logged) - g3.GET("/forks", handler.Forks, checkRequireLogin) - g3.PUT("/checkbox", handler.Checkbox, logged, writePermission) + sC.Use(makeCheckRequireLogin(true), GistInit) + sC.GET("", handler.GistIndex) + sC.GET("/rev/:revision", handler.GistIndex) + sC.GET("/revisions", handler.Revisions) + sC.GET("/archive/:revision", handler.DownloadZip) + sC.POST("/visibility", handler.EditVisibility, logged, writePermission) + sC.POST("/delete", handler.DeleteGist, logged, writePermission) + sC.GET("/raw/:revision/:file", handler.RawFile) + sC.GET("/download/:revision/:file", handler.DownloadFile) + sC.GET("/edit", handler.Edit, logged, writePermission) + sC.POST("/edit", handler.ProcessCreate, logged, writePermission) + sC.POST("/like", handler.Like, logged) + sC.GET("/likes", handler.Likes, checkRequireLogin) + sC.POST("/fork", handler.Fork, logged) + sC.GET("/forks", handler.Forks, checkRequireLogin) + sC.PUT("/checkbox", handler.Checkbox, logged, writePermission) } } @@ -161,7 +163,7 @@ func (r *Router) SubGroup(prefix string, m ...Middleware) *Router { for i, mw := range m { mw := mw // capture for closure echoMiddleware[i] = func(next echo.HandlerFunc) echo.HandlerFunc { - return Chain(func(c *context.OGContext) error { + return chain(func(c *context.OGContext) error { return next(c) }, mw).toEchoHandler() } @@ -169,37 +171,35 @@ func (r *Router) SubGroup(prefix string, m ...Middleware) *Router { return NewRouter(r.Group.Group(prefix, echoMiddleware...)) } -// Route registration methods func (r *Router) GET(path string, h Handler, m ...Middleware) { - r.Group.GET(path, Chain(h, m...).toEchoHandler()) + r.Group.GET(path, chain(h, m...).toEchoHandler()) } func (r *Router) POST(path string, h Handler, m ...Middleware) { - r.Group.POST(path, Chain(h, m...).toEchoHandler()) + r.Group.POST(path, chain(h, m...).toEchoHandler()) } func (r *Router) PUT(path string, h Handler, m ...Middleware) { - r.Group.PUT(path, Chain(h, m...).toEchoHandler()) + r.Group.PUT(path, chain(h, m...).toEchoHandler()) } func (r *Router) DELETE(path string, h Handler, m ...Middleware) { - r.Group.DELETE(path, Chain(h, m...).toEchoHandler()) + r.Group.DELETE(path, chain(h, m...).toEchoHandler()) } func (r *Router) PATCH(path string, h Handler, m ...Middleware) { - r.Group.PATCH(path, Chain(h, m...).toEchoHandler()) + r.Group.PATCH(path, chain(h, m...).toEchoHandler()) } func (r *Router) Any(path string, h Handler, m ...Middleware) { - r.Group.Any(path, Chain(h, m...).toEchoHandler()) + r.Group.Any(path, chain(h, m...).toEchoHandler()) } -// Use registers middleware for the entire router group func (r *Router) Use(middleware ...Middleware) { for _, m := range middleware { m := m // capture for closure r.Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return Chain(func(c *context.OGContext) error { + return chain(func(c *context.OGContext) error { return next(c) }, m).toEchoHandler() }) diff --git a/internal/web/server/server.go b/internal/web/server/server.go index 987377b..8019adf 100644 --- a/internal/web/server/server.go +++ b/internal/web/server/server.go @@ -3,30 +3,16 @@ package server import ( "errors" "github.com/thomiceli/opengist/internal/utils" - "github.com/thomiceli/opengist/internal/web/context" - "github.com/thomiceli/opengist/internal/web/handler" - "html/template" - "io" "net/http" "strings" "github.com/gorilla/sessions" "github.com/labstack/echo/v4" "github.com/rs/zerolog/log" - "github.com/thomiceli/opengist/internal/auth" "github.com/thomiceli/opengist/internal/config" - "github.com/thomiceli/opengist/internal/db" "github.com/thomiceli/opengist/internal/i18n" ) -type Template struct { - templates *template.Template -} - -func (t *Template) Render(w io.Writer, name string, data interface{}, _ echo.Context) error { - return t.templates.ExecuteTemplate(w, name, data) -} - type Server struct { echo *echo.Echo flashStore *sessions.CookieStore // session store for flash messages @@ -50,17 +36,17 @@ func NewServer(isDev bool, sessionsPath string, ignoreCsrf bool) *Server { log.Fatal().Err(err).Msg("Failed to load locales") } - s.RegisterMiddlewares() + s.registerMiddlewares() s.setFuncMap() s.echo.HTTPErrorHandler = s.errorHandler e.Validator = utils.NewValidator() if !s.dev { - parseManifestEntries() + s.parseManifestEntries() } - s.setupRoutes() + s.registerRoutes() return s } @@ -83,96 +69,3 @@ func (s *Server) Stop() { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.echo.ServeHTTP(w, r) } - -func writePermission(next Handler) Handler { - return func(ctx *context.OGContext) error { - gist := ctx.GetData("gist") - user := ctx.User - if !gist.(*db.Gist).CanWrite(user) { - return ctx.RedirectTo("/" + gist.(*db.Gist).User.Username + "/" + gist.(*db.Gist).Identifier()) - } - return next(ctx) - } -} - -func adminPermission(next Handler) Handler { - return func(ctx *context.OGContext) error { - user := ctx.User - if user == nil || !user.IsAdmin { - return ctx.NotFound("User not found") - } - return next(ctx) - } -} - -func logged(next Handler) Handler { - return func(ctx *context.OGContext) error { - user := ctx.User - if user != nil { - return next(ctx) - } - return ctx.RedirectTo("/all") - } -} - -func inMFASession(next Handler) Handler { - return func(ctx *context.OGContext) error { - sess := ctx.GetSession() - _, ok := sess.Values["mfaID"].(uint) - if !ok { - return ctx.ErrorRes(400, ctx.Tr("error.not-in-mfa-session"), nil) - } - return next(ctx) - } -} - -func makeCheckRequireLogin(isSingleGistAccess bool) Middleware { - return func(next Handler) Handler { - return func(ctx *context.OGContext) error { - if user := ctx.User; user != nil { - return next(ctx) - } - - allow, err := auth.ShouldAllowUnauthenticatedGistAccess(handler.ContextAuthInfo{Context: ctx}, isSingleGistAccess) - if err != nil { - log.Fatal().Err(err).Msg("Failed to check if unauthenticated access is allowed") - } - - if !allow { - ctx.AddFlash(ctx.Tr("flash.auth.must-be-logged-in"), "error") - return ctx.RedirectTo("/login") - } - return next(ctx) - } - } -} - -func checkRequireLogin(next Handler) Handler { - return makeCheckRequireLogin(false)(next) -} - -func noRouteFound(ctx *context.OGContext) error { - return ctx.NotFound("Page not found") -} - -func (s *Server) errorHandler(err error, ctx echo.Context) { - var httpErr *echo.HTTPError - if errors.As(err, &httpErr) { - acceptJson := strings.Contains(ctx.Request().Header.Get("Accept"), "application/json") - data := ctx.Request().Context().Value("data").(echo.Map) - data["error"] = err - if acceptJson { - if err := ctx.JSON(httpErr.Code, httpErr); err != nil { - log.Fatal().Err(err).Send() - } - return - } - - if err := ctx.Render(httpErr.Code, "error", data); err != nil { - log.Fatal().Err(err).Send() - } - return - } - - log.Fatal().Err(err).Send() -}