Skip to content

feat: Add GIT_COMMITTER information to agent env vars #1171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ import (
"golang.org/x/xerrors"
)

type Options struct {
EnvironmentVariables map[string]string
StartupScript string
type Metadata struct {
OwnerEmail string `json:"owner_email"`
OwnerUsername string `json:"owner_username"`
EnvironmentVariables map[string]string `json:"environment_variables"`
StartupScript string `json:"startup_script"`
}

type Dialer func(ctx context.Context, logger slog.Logger) (*Options, *peerbroker.Listener, error)
type Dialer func(ctx context.Context, logger slog.Logger) (Metadata, *peerbroker.Listener, error)

func New(dialer Dialer, logger slog.Logger) io.Closer {
ctx, cancelFunc := context.WithCancel(context.Background())
Expand All @@ -62,14 +64,16 @@ type agent struct {
closed chan struct{}

// Environment variables sent by Coder to inject for shell sessions.
// This is atomic because values can change after reconnect.
// These are atomic because values can change after reconnect.
envVars atomic.Value
ownerEmail atomic.String
ownerUsername atomic.String
startupScript atomic.Bool
sshServer *ssh.Server
}

func (a *agent) run(ctx context.Context) {
var options *Options
var options Metadata
var peerListener *peerbroker.Listener
var err error
// An exponential back-off occurs when the connection is failing to dial.
Expand All @@ -95,6 +99,8 @@ func (a *agent) run(ctx context.Context) {
default:
}
a.envVars.Store(options.EnvironmentVariables)
a.ownerEmail.Store(options.OwnerEmail)
a.ownerUsername.Store(options.OwnerUsername)

if a.startupScript.CAS(false, true) {
// The startup script has not ran yet!
Expand Down Expand Up @@ -303,8 +309,20 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
}
cmd := exec.CommandContext(session.Context(), shell, caller, command)
cmd.Env = append(os.Environ(), session.Environ()...)
executablePath, err := os.Executable()
if err != nil {
return xerrors.Errorf("getting os executable: %w", err)
}
// Git on Windows resolves with UNIX-style paths.
// If using backslashes, it's unable to find the executable.
executablePath = strings.ReplaceAll(executablePath, "\\", "/")
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, executablePath))
// These prevent the user from having to specify _anything_ to successfully commit.
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_COMMITTER_EMAIL=%s`, a.ownerEmail.Load()))
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_COMMITTER_NAME=%s`, a.ownerUsername.Load()))

// Load environment variables passed via the agent.
// These should override all variables we manually specify.
envVars := a.envVars.Load()
if envVars != nil {
envVarMap, ok := envVars.(map[string]string)
Expand All @@ -315,15 +333,6 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
}
}

executablePath, err := os.Executable()
if err != nil {
return xerrors.Errorf("getting os executable: %w", err)
}
// Git on Windows resolves with UNIX-style paths.
// If using backslashes, it's unable to find the executable.
executablePath = strings.ReplaceAll(executablePath, "\\", "/")
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, executablePath))

sshPty, windowSize, isPty := session.Pty()
if isPty {
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
Expand Down
23 changes: 10 additions & 13 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestAgent(t *testing.T) {
t.Parallel()
t.Run("SessionExec", func(t *testing.T) {
t.Parallel()
session := setupSSHSession(t, nil)
session := setupSSHSession(t, agent.Metadata{})

command := "echo test"
if runtime.GOOS == "windows" {
Expand All @@ -53,7 +53,7 @@ func TestAgent(t *testing.T) {

t.Run("GitSSH", func(t *testing.T) {
t.Parallel()
session := setupSSHSession(t, nil)
session := setupSSHSession(t, agent.Metadata{})
command := "sh -c 'echo $GIT_SSH_COMMAND'"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo %GIT_SSH_COMMAND%"
Expand All @@ -71,7 +71,7 @@ func TestAgent(t *testing.T) {
// it seems like it could be either.
t.Skip("ConPTY appears to be inconsistent on Windows.")
}
session := setupSSHSession(t, nil)
session := setupSSHSession(t, agent.Metadata{})
command := "bash"
if runtime.GOOS == "windows" {
command = "cmd.exe"
Expand Down Expand Up @@ -131,7 +131,7 @@ func TestAgent(t *testing.T) {

t.Run("SFTP", func(t *testing.T) {
t.Parallel()
sshClient, err := setupAgent(t, nil).SSHClient()
sshClient, err := setupAgent(t, agent.Metadata{}).SSHClient()
require.NoError(t, err)
client, err := sftp.NewClient(sshClient)
require.NoError(t, err)
Expand All @@ -148,7 +148,7 @@ func TestAgent(t *testing.T) {
t.Parallel()
key := "EXAMPLE"
value := "value"
session := setupSSHSession(t, &agent.Options{
session := setupSSHSession(t, agent.Metadata{
EnvironmentVariables: map[string]string{
key: value,
},
Expand All @@ -166,7 +166,7 @@ func TestAgent(t *testing.T) {
t.Parallel()
tempPath := filepath.Join(os.TempDir(), "content.txt")
content := "somethingnice"
setupAgent(t, &agent.Options{
setupAgent(t, agent.Metadata{
StartupScript: "echo " + content + " > " + tempPath,
})
var gotContent string
Expand All @@ -191,7 +191,7 @@ func TestAgent(t *testing.T) {
}

func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
agentConn := setupAgent(t, nil)
agentConn := setupAgent(t, agent.Metadata{})
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
go func() {
Expand Down Expand Up @@ -219,20 +219,17 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
return exec.Command("ssh", args...)
}

func setupSSHSession(t *testing.T, options *agent.Options) *ssh.Session {
func setupSSHSession(t *testing.T, options agent.Metadata) *ssh.Session {
sshClient, err := setupAgent(t, options).SSHClient()
require.NoError(t, err)
session, err := sshClient.NewSession()
require.NoError(t, err)
return session
}

func setupAgent(t *testing.T, options *agent.Options) *agent.Conn {
if options == nil {
options = &agent.Options{}
}
func setupAgent(t *testing.T, options agent.Metadata) *agent.Conn {
client, server := provisionersdk.TransportPipe()
closer := agent.New(func(ctx context.Context, logger slog.Logger) (*agent.Options, *peerbroker.Listener, error) {
closer := agent.New(func(ctx context.Context, logger slog.Logger) (agent.Metadata, *peerbroker.Listener, error) {
listener, err := peerbroker.Listen(server, nil)
return options, listener, err
}, slogtest.Make(t, nil).Leveled(slog.LevelDebug))
Expand Down
2 changes: 1 addition & 1 deletion coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func New(options *Options) (http.Handler, func()) {
r.Post("/google-instance-identity", api.postWorkspaceAuthGoogleInstanceIdentity)
r.Route("/me", func(r chi.Router) {
r.Use(httpmw.ExtractWorkspaceAgent(options.Database))
r.Get("/", api.workspaceAgentMe)
r.Get("/metadata", api.workspaceAgentMetadata)
r.Get("/listen", api.workspaceAgentListen)
r.Get("/gitsshkey", api.agentGitSSHKey)
r.Get("/turn", api.workspaceAgentTurn)
Expand Down
84 changes: 59 additions & 25 deletions coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"nhooyr.io/websocket"

"cdr.dev/slog"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
Expand All @@ -25,8 +26,8 @@ import (
)

func (api *api) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
agent := httpmw.WorkspaceAgentParam(r)
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
workspaceAgent := httpmw.WorkspaceAgentParam(r)
apiAgent, err := convertWorkspaceAgent(workspaceAgent, api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("convert workspace agent: %s", err),
Expand All @@ -43,8 +44,8 @@ func (api *api) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()

agent := httpmw.WorkspaceAgentParam(r)
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
workspaceAgent := httpmw.WorkspaceAgentParam(r)
apiAgent, err := convertWorkspaceAgent(workspaceAgent, api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("convert workspace agent: %s", err),
Expand Down Expand Up @@ -78,7 +79,7 @@ func (api *api) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
return
}
err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{
ChannelID: agent.ID.String(),
ChannelID: workspaceAgent.ID.String(),
Logger: api.Logger.Named("peerbroker-proxy-dial"),
Pubsub: api.Pubsub,
})
Expand All @@ -88,16 +89,49 @@ func (api *api) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
}
}

func (api *api) workspaceAgentMe(rw http.ResponseWriter, r *http.Request) {
agent := httpmw.WorkspaceAgent(r)
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
func (api *api) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
workspaceAgent := httpmw.WorkspaceAgent(r)
apiAgent, err := convertWorkspaceAgent(workspaceAgent, api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("convert workspace agent: %s", err),
})
return
}
httpapi.Write(rw, http.StatusOK, apiAgent)
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace resource: %s", err),
})
return
}
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace build: %s", err),
})
return
}
workspace, err := api.Database.GetWorkspaceByID(r.Context(), build.WorkspaceID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace build: %s", err),
})
return
}
owner, err := api.Database.GetUserByID(r.Context(), workspace.OwnerID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace build: %s", err),
})
return
}
httpapi.Write(rw, http.StatusOK, agent.Metadata{
OwnerEmail: owner.Email,
OwnerUsername: owner.Username,
EnvironmentVariables: apiAgent.EnvironmentVariables,
StartupScript: apiAgent.StartupScript,
})
}

func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
Expand All @@ -106,7 +140,7 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()

agent := httpmw.WorkspaceAgent(r)
workspaceAgent := httpmw.WorkspaceAgent(r)
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
Expand All @@ -116,7 +150,7 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
})
return
}
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), agent.ResourceID)
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("accept websocket: %s", err),
Expand All @@ -135,7 +169,7 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
return
}
closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{
ChannelID: agent.ID.String(),
ChannelID: workspaceAgent.ID.String(),
Pubsub: api.Pubsub,
Logger: api.Logger.Named("peerbroker-proxy-listen"),
})
Expand All @@ -144,7 +178,7 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
return
}
defer closer.Close()
firstConnectedAt := agent.FirstConnectedAt
firstConnectedAt := workspaceAgent.FirstConnectedAt
if !firstConnectedAt.Valid {
firstConnectedAt = sql.NullTime{
Time: database.Now(),
Expand All @@ -155,10 +189,10 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
Time: database.Now(),
Valid: true,
}
disconnectedAt := agent.DisconnectedAt
disconnectedAt := workspaceAgent.DisconnectedAt
updateConnectionTimes := func() error {
err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{
ID: agent.ID,
ID: workspaceAgent.ID,
FirstConnectedAt: firstConnectedAt,
LastConnectedAt: lastConnectedAt,
DisconnectedAt: disconnectedAt,
Expand Down Expand Up @@ -205,7 +239,7 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
return
}

api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", agent))
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))

ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
defer ticker.Stop()
Expand Down Expand Up @@ -294,7 +328,7 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal: %w", err)
}
}
agent := codersdk.WorkspaceAgent{
workspaceAgent := codersdk.WorkspaceAgent{
ID: dbAgent.ID,
CreatedAt: dbAgent.CreatedAt,
UpdatedAt: dbAgent.UpdatedAt,
Expand All @@ -307,31 +341,31 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency
EnvironmentVariables: envs,
}
if dbAgent.FirstConnectedAt.Valid {
agent.FirstConnectedAt = &dbAgent.FirstConnectedAt.Time
workspaceAgent.FirstConnectedAt = &dbAgent.FirstConnectedAt.Time
}
if dbAgent.LastConnectedAt.Valid {
agent.LastConnectedAt = &dbAgent.LastConnectedAt.Time
workspaceAgent.LastConnectedAt = &dbAgent.LastConnectedAt.Time
}
if dbAgent.DisconnectedAt.Valid {
agent.DisconnectedAt = &dbAgent.DisconnectedAt.Time
workspaceAgent.DisconnectedAt = &dbAgent.DisconnectedAt.Time
}
switch {
case !dbAgent.FirstConnectedAt.Valid:
// If the agent never connected, it's waiting for the compute
// to start up.
agent.Status = codersdk.WorkspaceAgentConnecting
workspaceAgent.Status = codersdk.WorkspaceAgentConnecting
case dbAgent.DisconnectedAt.Time.After(dbAgent.LastConnectedAt.Time):
// If we've disconnected after our last connection, we know the
// agent is no longer connected.
agent.Status = codersdk.WorkspaceAgentDisconnected
workspaceAgent.Status = codersdk.WorkspaceAgentDisconnected
case agentUpdateFrequency*2 >= database.Now().Sub(dbAgent.LastConnectedAt.Time):
// The connection updated it's timestamp within the update frequency.
// We multiply by two to allow for some lag.
agent.Status = codersdk.WorkspaceAgentConnected
workspaceAgent.Status = codersdk.WorkspaceAgentConnected
case database.Now().Sub(dbAgent.LastConnectedAt.Time) > agentUpdateFrequency*2:
// The connection died without updating the last connected.
agent.Status = codersdk.WorkspaceAgentDisconnected
workspaceAgent.Status = codersdk.WorkspaceAgentDisconnected
}

return agent, nil
return workspaceAgent, nil
}
2 changes: 0 additions & 2 deletions coderd/workspaceagents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ func TestWorkspaceAgentListen(t *testing.T) {
})
_, err = conn.Ping()
require.NoError(t, err)
_, err = agentClient.WorkspaceAgent(context.Background(), codersdk.Me)
require.NoError(t, err)
}

func TestWorkspaceAgentTURN(t *testing.T) {
Expand Down
Loading