Skip to content

Commit 023275c

Browse files
committed
feat: add timeout support to workspace bash tool
Change-Id: I996cbde4a50debb54a0a95ca5a067781719fa25a Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent 070178c commit 023275c

File tree

2 files changed

+320
-11
lines changed

2 files changed

+320
-11
lines changed

codersdk/toolsdk/bash.go

Lines changed: 140 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package toolsdk
22

33
import (
4+
"bytes"
45
"context"
56
"errors"
67
"fmt"
78
"io"
89
"strings"
10+
"sync"
11+
"time"
912

1013
gossh "golang.org/x/crypto/ssh"
1114
"golang.org/x/xerrors"
@@ -20,6 +23,7 @@ import (
2023
type WorkspaceBashArgs struct {
2124
Workspace string `json:"workspace"`
2225
Command string `json:"command"`
26+
TimeoutMs int `json:"timeout_ms,omitempty"`
2327
}
2428

2529
type WorkspaceBashResult struct {
@@ -43,9 +47,12 @@ The workspace parameter supports various formats:
4347
- workspace.agent (specific agent)
4448
- owner/workspace.agent
4549
50+
The timeout_ms parameter specifies the command timeout in milliseconds (defaults to 60000ms, maximum of 300000ms).
51+
If the command times out, all output captured up to that point is returned with a cancellation message.
52+
4653
Examples:
4754
- workspace: "my-workspace", command: "ls -la"
48-
- workspace: "john/dev-env", command: "git status"
55+
- workspace: "john/dev-env", command: "git status", timeout_ms: 30000
4956
- workspace: "my-workspace.main", command: "docker ps"`,
5057
Schema: aisdk.Schema{
5158
Properties: map[string]any{
@@ -57,18 +64,27 @@ Examples:
5764
"type": "string",
5865
"description": "The bash command to execute in the workspace.",
5966
},
67+
"timeout_ms": map[string]any{
68+
"type": "integer",
69+
"description": "Command timeout in milliseconds. Defaults to 60000ms (60 seconds) if not specified.",
70+
"default": 60000,
71+
"minimum": 1,
72+
},
6073
},
6174
Required: []string{"workspace", "command"},
6275
},
6376
},
64-
Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (WorkspaceBashResult, error) {
77+
Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (res WorkspaceBashResult, err error) {
6578
if args.Workspace == "" {
6679
return WorkspaceBashResult{}, xerrors.New("workspace name cannot be empty")
6780
}
6881
if args.Command == "" {
6982
return WorkspaceBashResult{}, xerrors.New("command cannot be empty")
7083
}
7184

85+
ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, errors.New("MCP handler timeout after 5 min"))
86+
defer cancel()
87+
7288
// Normalize workspace input to handle various formats
7389
workspaceName := NormalizeWorkspaceInput(args.Workspace)
7490

@@ -119,23 +135,41 @@ Examples:
119135
}
120136
defer session.Close()
121137

122-
// Execute command and capture output
123-
output, err := session.CombinedOutput(args.Command)
138+
// Set default timeout if not specified (60 seconds)
139+
timeoutMs := args.TimeoutMs
140+
if timeoutMs <= 0 {
141+
timeoutMs = 60000
142+
}
143+
144+
// Create context with timeout
145+
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond)
146+
defer cancel()
147+
148+
// Execute command with timeout handling
149+
output, err := executeCommandWithTimeout(ctx, session, args.Command)
124150
outputStr := strings.TrimSpace(string(output))
125151

152+
// Handle command execution results
126153
if err != nil {
127-
// Check if it's an SSH exit error to get the exit code
128-
var exitErr *gossh.ExitError
129-
if errors.As(err, &exitErr) {
154+
// Check if the command timed out
155+
if context.Cause(ctx) == context.DeadlineExceeded {
156+
outputStr += "\nCommand cancelled due to timeout"
130157
return WorkspaceBashResult{
131158
Output: outputStr,
132-
ExitCode: exitErr.ExitStatus(),
159+
ExitCode: 124,
133160
}, nil
134161
}
135-
// For other errors, return exit code 1
162+
163+
// Extract exit code from SSH error if available
164+
exitCode := 1
165+
if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) {
166+
exitCode = exitErr.ExitStatus()
167+
}
168+
169+
// For other errors, use standard timeout or generic error code
136170
return WorkspaceBashResult{
137171
Output: outputStr,
138-
ExitCode: 1,
172+
ExitCode: exitCode,
139173
}, nil
140174
}
141175

@@ -292,3 +326,99 @@ func NormalizeWorkspaceInput(input string) string {
292326

293327
return normalized
294328
}
329+
330+
// executeCommandWithTimeout executes a command with timeout support
331+
func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, command string) ([]byte, error) {
332+
// Set up pipes to capture output
333+
stdoutPipe, err := session.StdoutPipe()
334+
if err != nil {
335+
return nil, xerrors.Errorf("failed to create stdout pipe: %w", err)
336+
}
337+
338+
stderrPipe, err := session.StderrPipe()
339+
if err != nil {
340+
return nil, xerrors.Errorf("failed to create stderr pipe: %w", err)
341+
}
342+
343+
// Start the command
344+
if err := session.Start(command); err != nil {
345+
return nil, xerrors.Errorf("failed to start command: %w", err)
346+
}
347+
348+
// Create a thread-safe buffer for combined output
349+
var output bytes.Buffer
350+
var mu sync.Mutex
351+
safeWriter := &syncWriter{w: &output, mu: &mu}
352+
353+
// Use io.MultiWriter to combine stdout and stderr
354+
multiWriter := io.MultiWriter(safeWriter)
355+
356+
// Channel to signal when command completes
357+
done := make(chan error, 1)
358+
359+
// Start goroutine to copy output and wait for completion
360+
go func() {
361+
// Copy stdout and stderr concurrently
362+
var wg sync.WaitGroup
363+
wg.Add(2)
364+
365+
go func() {
366+
defer wg.Done()
367+
_, _ = io.Copy(multiWriter, stdoutPipe)
368+
}()
369+
370+
go func() {
371+
defer wg.Done()
372+
_, _ = io.Copy(multiWriter, stderrPipe)
373+
}()
374+
375+
// Wait for all output to be copied
376+
wg.Wait()
377+
378+
// Wait for the command to complete
379+
done <- session.Wait()
380+
}()
381+
382+
// Wait for either completion or context cancellation
383+
select {
384+
case err := <-done:
385+
// Command completed normally
386+
return safeWriter.Bytes(), err
387+
case <-ctx.Done():
388+
// Context was cancelled (timeout or other cancellation)
389+
// Close the session to stop the command
390+
_ = session.Close()
391+
392+
// Give a brief moment to collect any remaining output
393+
timer := time.NewTimer(50 * time.Millisecond)
394+
defer timer.Stop()
395+
396+
select {
397+
case <-timer.C:
398+
// Timer expired, return what we have
399+
case err := <-done:
400+
// Command finished during grace period
401+
return safeWriter.Bytes(), err
402+
}
403+
404+
return safeWriter.Bytes(), context.Cause(ctx)
405+
}
406+
}
407+
408+
// syncWriter is a thread-safe writer
409+
type syncWriter struct {
410+
w *bytes.Buffer
411+
mu *sync.Mutex
412+
}
413+
414+
func (sw *syncWriter) Write(p []byte) (n int, err error) {
415+
sw.mu.Lock()
416+
defer sw.mu.Unlock()
417+
return sw.w.Write(p)
418+
}
419+
420+
func (sw *syncWriter) Bytes() []byte {
421+
sw.mu.Lock()
422+
defer sw.mu.Unlock()
423+
return sw.w.Bytes()
424+
}

0 commit comments

Comments
 (0)