Skip to content

Commit 00dd127

Browse files
committed
feat: add timeout support to workspace bash tool
Change-Id: I996cbde4a50debb54a0a95ca5a067781719fa25a Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent 5c31b98 commit 00dd127

File tree

2 files changed

+321
-11
lines changed

2 files changed

+321
-11
lines changed

codersdk/toolsdk/bash.go

Lines changed: 141 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package toolsdk
22

33
import (
4+
"bytes"
45
"context"
56
"errors"
67
"fmt"
78
"io"
89
"strings"
10+
"sync"
11+
"time"
912

1013
gossh "golang.org/x/crypto/ssh"
1114
"golang.org/x/xerrors"
@@ -20,6 +23,7 @@ import (
2023
type WorkspaceBashArgs struct {
2124
Workspace string `json:"workspace"`
2225
Command string `json:"command"`
26+
TimeoutMs int `json:"timeout_ms,omitempty"`
2327
}
2428

2529
type WorkspaceBashResult struct {
@@ -43,9 +47,12 @@ The workspace parameter supports various formats:
4347
- workspace.agent (specific agent)
4448
- owner/workspace.agent
4549
50+
The timeout_ms parameter specifies the command timeout in milliseconds (defaults to 60000ms, maximum of 300000ms).
51+
If the command times out, all output captured up to that point is returned with a cancellation message.
52+
4653
Examples:
4754
- workspace: "my-workspace", command: "ls -la"
48-
- workspace: "john/dev-env", command: "git status"
55+
- workspace: "john/dev-env", command: "git status", timeout_ms: 30000
4956
- workspace: "my-workspace.main", command: "docker ps"`,
5057
Schema: aisdk.Schema{
5158
Properties: map[string]any{
@@ -57,18 +64,27 @@ Examples:
5764
"type": "string",
5865
"description": "The bash command to execute in the workspace.",
5966
},
67+
"timeout_ms": map[string]any{
68+
"type": "integer",
69+
"description": "Command timeout in milliseconds. Defaults to 60000ms (60 seconds) if not specified.",
70+
"default": 60000,
71+
"minimum": 1,
72+
},
6073
},
6174
Required: []string{"workspace", "command"},
6275
},
6376
},
64-
Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (WorkspaceBashResult, error) {
77+
Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (res WorkspaceBashResult, err error) {
6578
if args.Workspace == "" {
6679
return WorkspaceBashResult{}, xerrors.New("workspace name cannot be empty")
6780
}
6881
if args.Command == "" {
6982
return WorkspaceBashResult{}, xerrors.New("command cannot be empty")
7083
}
7184

85+
ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, xerrors.New("MCP handler timeout after 5 min"))
86+
defer cancel()
87+
7288
// Normalize workspace input to handle various formats
7389
workspaceName := NormalizeWorkspaceInput(args.Workspace)
7490

@@ -119,23 +135,42 @@ Examples:
119135
}
120136
defer session.Close()
121137

122-
// Execute command and capture output
123-
output, err := session.CombinedOutput(args.Command)
138+
// Set default timeout if not specified (60 seconds)
139+
timeoutMs := args.TimeoutMs
140+
if timeoutMs <= 0 {
141+
timeoutMs = 60000
142+
}
143+
144+
// Create context with timeout
145+
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond)
146+
defer cancel()
147+
148+
// Execute command with timeout handling
149+
output, err := executeCommandWithTimeout(ctx, session, args.Command)
124150
outputStr := strings.TrimSpace(string(output))
125151

152+
// Handle command execution results
126153
if err != nil {
127-
// Check if it's an SSH exit error to get the exit code
128-
var exitErr *gossh.ExitError
129-
if errors.As(err, &exitErr) {
154+
// Check if the command timed out
155+
if errors.Is(context.Cause(ctx), context.DeadlineExceeded) {
156+
outputStr += "\nCommand canceled due to timeout"
130157
return WorkspaceBashResult{
131158
Output: outputStr,
132-
ExitCode: exitErr.ExitStatus(),
159+
ExitCode: 124,
133160
}, nil
134161
}
135-
// For other errors, return exit code 1
162+
163+
// Extract exit code from SSH error if available
164+
exitCode := 1
165+
var exitErr *gossh.ExitError
166+
if errors.As(err, &exitErr) {
167+
exitCode = exitErr.ExitStatus()
168+
}
169+
170+
// For other errors, use standard timeout or generic error code
136171
return WorkspaceBashResult{
137172
Output: outputStr,
138-
ExitCode: 1,
173+
ExitCode: exitCode,
139174
}, nil
140175
}
141176

@@ -292,3 +327,99 @@ func NormalizeWorkspaceInput(input string) string {
292327

293328
return normalized
294329
}
330+
331+
// executeCommandWithTimeout executes a command with timeout support
332+
func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, command string) ([]byte, error) {
333+
// Set up pipes to capture output
334+
stdoutPipe, err := session.StdoutPipe()
335+
if err != nil {
336+
return nil, xerrors.Errorf("failed to create stdout pipe: %w", err)
337+
}
338+
339+
stderrPipe, err := session.StderrPipe()
340+
if err != nil {
341+
return nil, xerrors.Errorf("failed to create stderr pipe: %w", err)
342+
}
343+
344+
// Start the command
345+
if err := session.Start(command); err != nil {
346+
return nil, xerrors.Errorf("failed to start command: %w", err)
347+
}
348+
349+
// Create a thread-safe buffer for combined output
350+
var output bytes.Buffer
351+
var mu sync.Mutex
352+
safeWriter := &syncWriter{w: &output, mu: &mu}
353+
354+
// Use io.MultiWriter to combine stdout and stderr
355+
multiWriter := io.MultiWriter(safeWriter)
356+
357+
// Channel to signal when command completes
358+
done := make(chan error, 1)
359+
360+
// Start goroutine to copy output and wait for completion
361+
go func() {
362+
// Copy stdout and stderr concurrently
363+
var wg sync.WaitGroup
364+
wg.Add(2)
365+
366+
go func() {
367+
defer wg.Done()
368+
_, _ = io.Copy(multiWriter, stdoutPipe)
369+
}()
370+
371+
go func() {
372+
defer wg.Done()
373+
_, _ = io.Copy(multiWriter, stderrPipe)
374+
}()
375+
376+
// Wait for all output to be copied
377+
wg.Wait()
378+
379+
// Wait for the command to complete
380+
done <- session.Wait()
381+
}()
382+
383+
// Wait for either completion or context cancellation
384+
select {
385+
case err := <-done:
386+
// Command completed normally
387+
return safeWriter.Bytes(), err
388+
case <-ctx.Done():
389+
// Context was canceled (timeout or other cancellation)
390+
// Close the session to stop the command
391+
_ = session.Close()
392+
393+
// Give a brief moment to collect any remaining output
394+
timer := time.NewTimer(50 * time.Millisecond)
395+
defer timer.Stop()
396+
397+
select {
398+
case <-timer.C:
399+
// Timer expired, return what we have
400+
case err := <-done:
401+
// Command finished during grace period
402+
return safeWriter.Bytes(), err
403+
}
404+
405+
return safeWriter.Bytes(), context.Cause(ctx)
406+
}
407+
}
408+
409+
// syncWriter is a thread-safe writer
410+
type syncWriter struct {
411+
w *bytes.Buffer
412+
mu *sync.Mutex
413+
}
414+
415+
func (sw *syncWriter) Write(p []byte) (n int, err error) {
416+
sw.mu.Lock()
417+
defer sw.mu.Unlock()
418+
return sw.w.Write(p)
419+
}
420+
421+
func (sw *syncWriter) Bytes() []byte {
422+
sw.mu.Lock()
423+
defer sw.mu.Unlock()
424+
return sw.w.Bytes()
425+
}

0 commit comments

Comments
 (0)