diff --git a/internal/models/sshkey.go b/internal/models/sshkey.go index d4b9b20..2dcdabc 100644 --- a/internal/models/sshkey.go +++ b/internal/models/sshkey.go @@ -48,7 +48,7 @@ func GetSSHKeyByID(sshKeyId uint) (*SSHKey, error) { return sshKey, err } -func GetSSHKeyByContent(sshKeyContent string) (*SSHKey, error) { +func SSHKeyDoesExists(sshKeyContent string) (*SSHKey, error) { sshKey := new(SSHKey) err := db. Where("content like ?", sshKeyContent+"%"). @@ -65,9 +65,9 @@ func (sshKey *SSHKey) Delete() error { return db.Delete(&sshKey).Error } -func SSHKeyLastUsedNow(sshKeyID uint) error { +func SSHKeyLastUsedNow(sshKeyContent string) error { return db.Model(&SSHKey{}). - Where("id = ?", sshKeyID). + Where("content = ?", sshKeyContent). Update("last_used_at", time.Now().Unix()).Error } diff --git a/internal/models/user.go b/internal/models/user.go index 9b0cabe..73682e2 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -81,15 +81,14 @@ func GetUserById(userId uint) (*User, error) { return user, err } -func GetUserBySSHKeyID(sshKeyId uint) (*User, error) { - user := new(User) +func SSHKeyExistsForUser(sshKey string, userId uint) (*SSHKey, error) { + key := new(SSHKey) err := db. - Preload("SSHKeys"). - Joins("join ssh_keys on users.id = ssh_keys.user_id"). - Where("ssh_keys.id = ?", sshKeyId). - First(&user).Error + Where("content = ?", sshKey). + Where("user_id = ?", userId). + First(&key).Error - return user, err + return key, err } func GetUserByProvider(id string, provider string) (*User, error) { diff --git a/internal/ssh/git_ssh.go b/internal/ssh/git_ssh.go index 8a18381..376d336 100644 --- a/internal/ssh/git_ssh.go +++ b/internal/ssh/git_ssh.go @@ -12,7 +12,7 @@ import ( "strings" ) -func runGitCommand(ch ssh.Channel, gitCmd string, keyID uint, ip string) error { +func runGitCommand(ch ssh.Channel, gitCmd string, key string, ip string) error { verb, args := parseCommand(gitCmd) if !strings.HasPrefix(verb, "git-") { verb = "" @@ -43,7 +43,7 @@ func runGitCommand(ch ssh.Channel, gitCmd string, keyID uint, ip string) error { } if verb == "receive-pack" || requireLogin == "1" { - user, err := models.GetUserBySSHKeyID(keyID) + pubKey, err := models.SSHKeyExistsForUser(key, gist.UserID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Warn().Msg("Invalid SSH authentication attempt from " + ip) @@ -52,15 +52,9 @@ func runGitCommand(ch ssh.Channel, gitCmd string, keyID uint, ip string) error { errorSsh("Failed to get user by SSH key id", err) return errors.New("internal server error") } - - if user.ID != gist.UserID { - log.Warn().Msg("Invalid SSH authentication attempt from " + ip) - return errors.New("unauthorized") - } + _ = models.SSHKeyLastUsedNow(pubKey.Content) } - _ = models.SSHKeyLastUsedNow(keyID) - repositoryPath := git.RepositoryPath(gist.User.Username, gist.Uuid) cmd := exec.Command("git", verb, repositoryPath) diff --git a/internal/ssh/run.go b/internal/ssh/run.go index 15c7f7b..a4edf7a 100644 --- a/internal/ssh/run.go +++ b/internal/ssh/run.go @@ -12,7 +12,6 @@ import ( "os" "os/exec" "path/filepath" - "strconv" "strings" "syscall" ) @@ -24,7 +23,8 @@ func Start() { sshConfig := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - pkey, err := models.GetSSHKeyByContent(strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key)))) + strKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key))) + _, err := models.SSHKeyDoesExists(strKey) if err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err @@ -33,7 +33,7 @@ func Start() { log.Warn().Msg("Invalid SSH authentication attempt from " + conn.RemoteAddr().String()) return nil, errors.New("unknown public key") } - return &ssh.Permissions{Extensions: map[string]string{"key-id": strconv.Itoa(int(pkey.ID))}}, nil + return &ssh.Permissions{Extensions: map[string]string{"key": strKey}}, nil }, } @@ -71,13 +71,12 @@ func listen(serverConfig *ssh.ServerConfig) { } go ssh.DiscardRequests(reqs) - keyID, _ := strconv.Atoi(sConn.Permissions.Extensions["key-id"]) - go handleConnexion(channels, uint(keyID), sConn.RemoteAddr().String()) + go handleConnexion(channels, sConn.Permissions.Extensions["key"], sConn.RemoteAddr().String()) }() } } -func handleConnexion(channels <-chan ssh.NewChannel, keyID uint, ip string) { +func handleConnexion(channels <-chan ssh.NewChannel, key string, ip string) { for channel := range channels { if channel.ChannelType() != "session" { _ = channel.Reject(ssh.UnknownChannelType, "Unknown channel type") @@ -109,7 +108,7 @@ func handleConnexion(channels <-chan ssh.NewChannel, keyID uint, ip string) { payloadCmd = payloadCmd[i:] } - if err = runGitCommand(ch, payloadCmd, keyID, ip); err != nil { + if err = runGitCommand(ch, payloadCmd, key, ip); err != nil { _, _ = ch.Stderr().Write([]byte("Opengist: " + err.Error() + "\r\n")) } _, _ = ch.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) diff --git a/internal/web/settings.go b/internal/web/settings.go index 795cb68..8317e31 100644 --- a/internal/web/settings.go +++ b/internal/web/settings.go @@ -74,11 +74,12 @@ func sshKeysProcess(ctx echo.Context) error { key.UserID = user.ID - _, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key.Content)) + pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key.Content)) if err != nil { addFlash(ctx, "Invalid SSH key", "error") return redirect(ctx, "/settings") } + key.Content = strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey))) if err := key.Create(); err != nil { return errorRes(500, "Cannot add SSH key", err) diff --git a/templates/pages/settings.html b/templates/pages/settings.html index ca6c824..3bbfebe 100644 --- a/templates/pages/settings.html +++ b/templates/pages/settings.html @@ -83,7 +83,7 @@ Add SSH Key