Skip to content

refactor(agent/agentssh): move parsing of magic session and create type #16630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
defer sshClient.Close()
session, err := sshClient.NewSession()
require.NoError(t, err)
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode)
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, string(agentssh.MagicSessionTypeVSCode))
defer session.Close()

command := "sh -c 'echo $" + agentssh.MagicSessionTypeEnvironmentVariable + "'"
Expand All @@ -165,7 +165,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
defer sshClient.Close()
session, err := sshClient.NewSession()
require.NoError(t, err)
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, agentssh.MagicSessionTypeVSCode)
session.Setenv(agentssh.MagicSessionTypeEnvironmentVariable, string(agentssh.MagicSessionTypeVSCode))
defer session.Close()
stdin, err := session.StdinPipe()
require.NoError(t, err)
Expand Down
134 changes: 85 additions & 49 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/spf13/afero"
"go.uber.org/atomic"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"

"cdr.dev/slog"
Expand All @@ -42,14 +43,6 @@ const (
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
MagicSessionErrorCode = 229

// MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
// This is stripped from any commands being executed, and is counted towards connection stats.
MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
// MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
MagicSessionTypeVSCode = "vscode"
// MagicSessionTypeJetBrains is set in the SSH config by the JetBrains
// extension to identify itself.
MagicSessionTypeJetBrains = "jetbrains"
// MagicProcessCmdlineJetBrains is a string in a process's command line that
// uniquely identifies it as JetBrains software.
MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains"
Expand All @@ -60,6 +53,29 @@ const (
BlockedFileTransferErrorMessage = "File transfer has been disabled."
)

// MagicSessionType is a type that represents the type of session that is being
// established.
type MagicSessionType string

const (
// MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
// This is stripped from any commands being executed, and is counted towards connection stats.
MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
)

// MagicSessionType enums.
const (
// MagicSessionTypeUnknown means the session type could not be determined.
MagicSessionTypeUnknown MagicSessionType = "unknown"
// MagicSessionTypeSSH is the default session type.
MagicSessionTypeSSH MagicSessionType = "ssh"
// MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
MagicSessionTypeVSCode MagicSessionType = "vscode"
// MagicSessionTypeJetBrains is set in the SSH config by the JetBrains
// extension to identify itself.
MagicSessionTypeJetBrains MagicSessionType = "jetbrains"
)

// BlockedFileTransferCommands contains a list of restricted file transfer commands.
var BlockedFileTransferCommands = []string{"nc", "rsync", "scp", "sftp"}

Expand Down Expand Up @@ -255,14 +271,42 @@ func (s *Server) ConnStats() ConnStats {
}
}

func extractMagicSessionType(env []string) (magicType MagicSessionType, rawType string, filteredEnv []string) {
for _, kv := range env {
if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable) {
continue
}

rawType = strings.TrimPrefix(kv, MagicSessionTypeEnvironmentVariable+"=")
// Keep going, we'll use the last instance of the env.
}

// Always force lowercase checking to be case-insensitive.
switch MagicSessionType(strings.ToLower(rawType)) {
case MagicSessionTypeVSCode:
magicType = MagicSessionTypeVSCode
case MagicSessionTypeJetBrains:
magicType = MagicSessionTypeJetBrains
case "", MagicSessionTypeSSH:
magicType = MagicSessionTypeSSH
default:
magicType = MagicSessionTypeUnknown
}

return magicType, rawType, slices.DeleteFunc(env, func(kv string) bool {
return strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable+"=")
})
}

Comment on lines +274 to +300
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: extractMagicSessionType(env []string) (MagicSessionType, filteredEnv, error) and return error in case of unknown magic session type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kind of like the suggestion, but I don't like returning valid data from a function that returns an error. Essentially we'd still want to filter the env in the case of an error parsing the magic session type. (We only warn log currently if it's invalid.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair!

func (s *Server) sessionHandler(session ssh.Session) {
ctx := session.Context()
id := uuid.New()
logger := s.logger.With(
slog.F("remote_addr", session.RemoteAddr()),
slog.F("local_addr", session.LocalAddr()),
// Assigning a random uuid for each session is useful for tracking
// logs for the same ssh session.
slog.F("id", uuid.NewString()),
slog.F("id", id.String()),
)
logger.Info(ctx, "handling ssh session")

Expand All @@ -274,16 +318,21 @@ func (s *Server) sessionHandler(session ssh.Session) {
}
defer s.trackSession(session, false)

extraEnv := make([]string, 0)
x11, hasX11 := session.X11()
if hasX11 {
display, handled := s.x11Handler(session.Context(), x11)
if !handled {
_ = session.Exit(1)
logger.Error(ctx, "x11 handler failed")
return
}
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
env := session.Environ()
magicType, magicTypeRaw, env := extractMagicSessionType(env)

switch magicType {
case MagicSessionTypeVSCode:
s.connCountVSCode.Add(1)
defer s.connCountVSCode.Add(-1)
case MagicSessionTypeJetBrains:
// Do nothing here because JetBrains launches hundreds of ssh sessions.
// We instead track JetBrains in the single persistent tcp forwarding channel.
case MagicSessionTypeSSH:
s.connCountSSHSession.Add(1)
defer s.connCountSSHSession.Add(-1)
case MagicSessionTypeUnknown:
logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("raw_type", magicTypeRaw))
}

if s.fileTransferBlocked(session) {
Expand All @@ -309,7 +358,18 @@ func (s *Server) sessionHandler(session ssh.Session) {
return
}

err := s.sessionStart(logger, session, extraEnv)
x11, hasX11 := session.X11()
if hasX11 {
display, handled := s.x11Handler(session.Context(), x11)
if !handled {
_ = session.Exit(1)
logger.Error(ctx, "x11 handler failed")
return
}
env = append(env, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
}

err := s.sessionStart(logger, session, env, magicType)
var exitError *exec.ExitError
if xerrors.As(err, &exitError) {
code := exitError.ExitCode()
Expand Down Expand Up @@ -379,32 +439,8 @@ func (s *Server) fileTransferBlocked(session ssh.Session) bool {
return false
}

func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv []string) (retErr error) {
func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, env []string, magicType MagicSessionType) (retErr error) {
ctx := session.Context()
env := append(session.Environ(), extraEnv...)
var magicType string
for index, kv := range env {
if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable) {
continue
}
magicType = strings.ToLower(strings.TrimPrefix(kv, MagicSessionTypeEnvironmentVariable+"="))
env = append(env[:index], env[index+1:]...)
}

// Always force lowercase checking to be case-insensitive.
switch magicType {
case MagicSessionTypeVSCode:
s.connCountVSCode.Add(1)
defer s.connCountVSCode.Add(-1)
case MagicSessionTypeJetBrains:
// Do nothing here because JetBrains launches hundreds of ssh sessions.
// We instead track JetBrains in the single persistent tcp forwarding channel.
case "":
s.connCountSSHSession.Add(1)
defer s.connCountSSHSession.Add(-1)
default:
logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType))
}

magicTypeLabel := magicTypeMetricLabel(magicType)
sshPty, windowSize, isPty := session.Pty()
Expand Down Expand Up @@ -473,7 +509,7 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
}()
go func() {
for sig := range sigs {
s.handleSignal(logger, sig, cmd.Process, magicTypeLabel)
handleSignal(logger, sig, cmd.Process, s.metrics, magicTypeLabel)
}
}()
return cmd.Wait()
Expand Down Expand Up @@ -558,7 +594,7 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
sigs = nil
continue
}
s.handleSignal(logger, sig, process, magicTypeLabel)
handleSignal(logger, sig, process, s.metrics, magicTypeLabel)
case win, ok := <-windowSize:
if !ok {
windowSize = nil
Expand Down Expand Up @@ -612,15 +648,15 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
return nil
}

func (s *Server) handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signal(os.Signal) error }, magicTypeLabel string) {
func handleSignal(logger slog.Logger, ssig ssh.Signal, signaler interface{ Signal(os.Signal) error }, metrics *sshServerMetrics, magicTypeLabel string) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

ctx := context.Background()
sig := osSignalFrom(ssig)
logger = logger.With(slog.F("ssh_signal", ssig), slog.F("signal", sig.String()))
logger.Info(ctx, "received signal from client")
err := signaler.Signal(sig)
if err != nil {
logger.Warn(ctx, "signaling the process failed", slog.Error(err))
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "signal").Add(1)
metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "signal").Add(1)
}
}

Expand Down
10 changes: 5 additions & 5 deletions agent/agentssh/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ func newSSHServerMetrics(registerer prometheus.Registerer) *sshServerMetrics {
}
}

func magicTypeMetricLabel(magicType string) string {
func magicTypeMetricLabel(magicType MagicSessionType) string {
switch magicType {
case MagicSessionTypeVSCode:
case MagicSessionTypeJetBrains:
case "":
magicType = "ssh"
Comment on lines -78 to -79
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This smells like some cutover compatibility logic here.

case MagicSessionTypeSSH:
case MagicSessionTypeUnknown:
default:
magicType = "unknown"
magicType = MagicSessionTypeUnknown
}
// Always be case insensitive
return strings.ToLower(magicType)
return strings.ToLower(string(magicType))
}
Loading