Skip to content

feat: add timeout support to workspace bash tool #19035

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 141 additions & 10 deletions codersdk/toolsdk/bash.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package toolsdk

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"strings"
"sync"
"time"

gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
Expand All @@ -20,6 +23,7 @@ import (
type WorkspaceBashArgs struct {
Workspace string `json:"workspace"`
Command string `json:"command"`
TimeoutMs int `json:"timeout_ms,omitempty"`
}

type WorkspaceBashResult struct {
Expand All @@ -43,9 +47,12 @@ The workspace parameter supports various formats:
- workspace.agent (specific agent)
- owner/workspace.agent

The timeout_ms parameter specifies the command timeout in milliseconds (defaults to 60000ms, maximum of 300000ms).
If the command times out, all output captured up to that point is returned with a cancellation message.

Examples:
- workspace: "my-workspace", command: "ls -la"
- workspace: "john/dev-env", command: "git status"
- workspace: "john/dev-env", command: "git status", timeout_ms: 30000
- workspace: "my-workspace.main", command: "docker ps"`,
Schema: aisdk.Schema{
Properties: map[string]any{
Expand All @@ -57,18 +64,27 @@ Examples:
"type": "string",
"description": "The bash command to execute in the workspace.",
},
"timeout_ms": map[string]any{
"type": "integer",
"description": "Command timeout in milliseconds. Defaults to 60000ms (60 seconds) if not specified.",
"default": 60000,
"minimum": 1,
},
},
Required: []string{"workspace", "command"},
},
},
Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (WorkspaceBashResult, error) {
Handler: func(ctx context.Context, deps Deps, args WorkspaceBashArgs) (res WorkspaceBashResult, err error) {
if args.Workspace == "" {
return WorkspaceBashResult{}, xerrors.New("workspace name cannot be empty")
}
if args.Command == "" {
return WorkspaceBashResult{}, xerrors.New("command cannot be empty")
}

ctx, cancel := context.WithTimeoutCause(ctx, 5*time.Minute, xerrors.New("MCP handler timeout after 5 min"))
defer cancel()

// Normalize workspace input to handle various formats
workspaceName := NormalizeWorkspaceInput(args.Workspace)

Expand Down Expand Up @@ -119,23 +135,42 @@ Examples:
}
defer session.Close()

// Execute command and capture output
output, err := session.CombinedOutput(args.Command)
// Set default timeout if not specified (60 seconds)
timeoutMs := args.TimeoutMs
if timeoutMs <= 0 {
timeoutMs = 60000
}

// Create context with timeout
ctx, cancel = context.WithTimeout(ctx, time.Duration(timeoutMs)*time.Millisecond)
defer cancel()

// Execute command with timeout handling
output, err := executeCommandWithTimeout(ctx, session, args.Command)
outputStr := strings.TrimSpace(string(output))

// Handle command execution results
if err != nil {
// Check if it's an SSH exit error to get the exit code
var exitErr *gossh.ExitError
if errors.As(err, &exitErr) {
// Check if the command timed out
if errors.Is(context.Cause(ctx), context.DeadlineExceeded) {
outputStr += "\nCommand canceled due to timeout"
return WorkspaceBashResult{
Output: outputStr,
ExitCode: exitErr.ExitStatus(),
ExitCode: 124,
}, nil
}
// For other errors, return exit code 1

// Extract exit code from SSH error if available
exitCode := 1
var exitErr *gossh.ExitError
if errors.As(err, &exitErr) {
exitCode = exitErr.ExitStatus()
}

// For other errors, use standard timeout or generic error code
return WorkspaceBashResult{
Output: outputStr,
ExitCode: 1,
ExitCode: exitCode,
}, nil
}

Expand Down Expand Up @@ -292,3 +327,99 @@ func NormalizeWorkspaceInput(input string) string {

return normalized
}

// executeCommandWithTimeout executes a command with timeout support
func executeCommandWithTimeout(ctx context.Context, session *gossh.Session, command string) ([]byte, error) {
// Set up pipes to capture output
stdoutPipe, err := session.StdoutPipe()
if err != nil {
return nil, xerrors.Errorf("failed to create stdout pipe: %w", err)
}

stderrPipe, err := session.StderrPipe()
if err != nil {
return nil, xerrors.Errorf("failed to create stderr pipe: %w", err)
}

// Start the command
if err := session.Start(command); err != nil {
return nil, xerrors.Errorf("failed to start command: %w", err)
}

// Create a thread-safe buffer for combined output
var output bytes.Buffer
var mu sync.Mutex
safeWriter := &syncWriter{w: &output, mu: &mu}

// Use io.MultiWriter to combine stdout and stderr
multiWriter := io.MultiWriter(safeWriter)

// Channel to signal when command completes
done := make(chan error, 1)

// Start goroutine to copy output and wait for completion
go func() {
// Copy stdout and stderr concurrently
var wg sync.WaitGroup
wg.Add(2)

go func() {
defer wg.Done()
_, _ = io.Copy(multiWriter, stdoutPipe)
}()

go func() {
defer wg.Done()
_, _ = io.Copy(multiWriter, stderrPipe)
}()

// Wait for all output to be copied
wg.Wait()

// Wait for the command to complete
done <- session.Wait()
}()

// Wait for either completion or context cancellation
select {
case err := <-done:
// Command completed normally
return safeWriter.Bytes(), err
case <-ctx.Done():
// Context was canceled (timeout or other cancellation)
// Close the session to stop the command
_ = session.Close()

// Give a brief moment to collect any remaining output
timer := time.NewTimer(50 * time.Millisecond)
defer timer.Stop()

select {
case <-timer.C:
// Timer expired, return what we have
case err := <-done:
// Command finished during grace period
return safeWriter.Bytes(), err
}

return safeWriter.Bytes(), context.Cause(ctx)
}
}

// syncWriter is a thread-safe writer
type syncWriter struct {
w *bytes.Buffer
mu *sync.Mutex
}

func (sw *syncWriter) Write(p []byte) (n int, err error) {
sw.mu.Lock()
defer sw.mu.Unlock()
return sw.w.Write(p)
}

func (sw *syncWriter) Bytes() []byte {
sw.mu.Lock()
defer sw.mu.Unlock()
return sw.w.Bytes()
}
Loading