first pass: notify agents when a prebuilt workspace has been claimed. #18929

Draft · wants to merge 1 commit into main
211 changes: 192 additions & 19 deletions agent/agent.go
@@ -95,8 +95,8 @@
}

type Client interface {
ConnectRPC26(ctx context.Context) (
proto.DRPCAgentClient26, tailnetproto.DRPCTailnetClient26, error,
ConnectRPC27(ctx context.Context) (
proto.DRPCAgentClient27, tailnetproto.DRPCTailnetClient26, error,
)
RewriteDERPMap(derpMap *tailcfg.DERPMap)
}
@@ -105,6 +105,23 @@
HTTPDebug() http.Handler
// TailnetConn may be nil.
TailnetConn() *tailnet.Conn
// SubscribeToClaimEvents allows other routines to subscribe to workspace claim events
//
// Example usage:
// unsubscribe := agent.SubscribeToClaimEvents(func(ctx context.Context, data interface{}) {
// claimData := data.(map[string]interface{})
// workspaceID := claimData["workspace_id"].(uuid.UUID)
// agentID := claimData["agent_id"].(uuid.UUID)
// claimedAt := claimData["claimed_at"].(time.Time)
//
// // React to the claim event
// // - Restart services that need user context
// // - Update configuration files
// // - Send notifications
// // etc.
// })
// defer unsubscribe()
SubscribeToClaimEvents(listener func(ctx context.Context, data interface{})) func()
io.Closer
}

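As a concrete companion to the doc comment above, a hypothetical consumer in another package could register a listener as shown below. The example package, the watchClaims helper, and the assumption that the enclosing interface is the exported agent.Agent are illustrative, not part of this diff; the payload keys mirror what the agent publishes further down ("workspace_id", "agent_id", "claimed_at").

```go
package example

import (
	"context"
	"log"
	"time"

	"github.com/google/uuid"

	"github.com/coder/coder/v2/agent"
)

// watchClaims registers a listener for prebuilt-workspace claim events and
// returns the unsubscribe function.
func watchClaims(a agent.Agent) (unsubscribe func()) {
	return a.SubscribeToClaimEvents(func(ctx context.Context, data interface{}) {
		claim, ok := data.(map[string]interface{})
		if !ok {
			return
		}
		workspaceID, _ := claim["workspace_id"].(uuid.UUID)
		agentID, _ := claim["agent_id"].(uuid.UUID)
		claimedAt, _ := claim["claimed_at"].(time.Time)
		// React to the claim, e.g. restart services that need user context.
		log.Printf("workspace %s (agent %s) claimed at %s", workspaceID, agentID, claimedAt)
	})
}
```

Whoever owns the listener should hold on to the returned unsubscribe function and call it (often via defer) once claim events are no longer of interest.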
@@ -196,6 +213,7 @@

devcontainers: options.Devcontainers,
containerAPIOptions: options.DevcontainerAPIOptions,
eventPubsub: newAgentEventPubsub(),
}
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
@@ -280,6 +298,77 @@
devcontainers bool
containerAPIOptions []agentcontainers.Option
containerAPI *agentcontainers.API

// In-memory pubsub for agent events
eventPubsub *agentEventPubsub
}

// agentEventPubsub provides an in-memory pubsub system for agent events
type agentEventPubsub struct {
mu sync.RWMutex
listeners map[string]map[uuid.UUID]func(ctx context.Context, data interface{})
}

func newAgentEventPubsub() *agentEventPubsub {
return &agentEventPubsub{
listeners: make(map[string]map[uuid.UUID]func(ctx context.Context, data interface{})),
}
}

func (p *agentEventPubsub) Subscribe(event string, listener func(ctx context.Context, data interface{})) func() {
p.mu.Lock()
defer p.mu.Unlock()

if p.listeners[event] == nil {
p.listeners[event] = make(map[uuid.UUID]func(ctx context.Context, data interface{}))
}

// Generate unique ID for this listener
listenerID := uuid.New()
p.listeners[event][listenerID] = listener

// Return unsubscribe function
return func() {
p.mu.Lock()
defer p.mu.Unlock()

if listeners, exists := p.listeners[event]; exists {
delete(listeners, listenerID)
// Clean up empty event maps
if len(listeners) == 0 {
delete(p.listeners, event)
}
}
}
}

func (p *agentEventPubsub) Publish(ctx context.Context, event string, data interface{}) {
p.mu.RLock()
listeners, exists := p.listeners[event]
if !exists {
p.mu.RUnlock()
return
}

// Create a copy of listeners to avoid holding the lock while calling them
listenersCopy := make(map[uuid.UUID]func(ctx context.Context, data interface{}))
for id, listener := range listeners {
listenersCopy[id] = listener
}
p.mu.RUnlock()

// Call all listeners in goroutines to avoid blocking
for _, listener := range listenersCopy {
go func(l func(ctx context.Context, data interface{})) {
defer func() {
if r := recover(); r != nil {
// Log panic but don't crash the agent
fmt.Printf("panic in agent event listener: %v\n", r)

[Check failure on line 366 in agent/agent.go (GitHub Actions / lint): unhandled-error: Unhandled error in call to function fmt.Printf (revive)]
}
}()
l(ctx, data)
}(listener)
}
}

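Because Publish invokes listeners on their own goroutines, a test has to synchronize on a channel rather than assert immediately after publishing. A minimal in-package test sketch follows; it is not part of this PR, and it assumes testify's require as used elsewhere in the repository.

```go
import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

func TestAgentEventPubsub(t *testing.T) {
	t.Parallel()
	ps := newAgentEventPubsub()

	got := make(chan interface{}, 1)
	unsubscribe := ps.Subscribe("workspace_claimed", func(_ context.Context, data interface{}) {
		got <- data
	})

	// Publish fans out asynchronously, so wait on the channel.
	ps.Publish(context.Background(), "workspace_claimed", "payload")
	select {
	case data := <-got:
		require.Equal(t, "payload", data)
	case <-time.After(time.Second):
		t.Fatal("listener was not invoked")
	}

	// After unsubscribing, later publishes must not reach the listener.
	unsubscribe()
	ps.Publish(context.Background(), "workspace_claimed", "ignored")
	select {
	case <-got:
		t.Fatal("listener invoked after unsubscribe")
	case <-time.After(50 * time.Millisecond):
	}
}
```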
func (a *agent) TailnetConn() *tailnet.Conn {
@@ -288,6 +377,10 @@
return a.network
}

func (a *agent) SubscribeToClaimEvents(listener func(ctx context.Context, data interface{})) func() {
return a.eventPubsub.Subscribe("workspace_claimed", listener)
}

func (a *agent) init() {
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.execer, &agentssh.Config{
@@ -472,7 +565,7 @@
fn()
}

func (a *agent) reportMetadata(ctx context.Context, aAPI proto.DRPCAgentClient26) error {
func (a *agent) reportMetadata(ctx context.Context, aAPI proto.DRPCAgentClient27) error {
tickerDone := make(chan struct{})
collectDone := make(chan struct{})
ctx, cancel := context.WithCancel(ctx)
@@ -687,7 +780,7 @@

// reportLifecycle reports the current lifecycle state once. All state
// changes are reported in order.
func (a *agent) reportLifecycle(ctx context.Context, aAPI proto.DRPCAgentClient26) error {
func (a *agent) reportLifecycle(ctx context.Context, aAPI proto.DRPCAgentClient27) error {
for {
select {
case <-a.lifecycleUpdate:
@@ -767,7 +860,7 @@
}

// reportConnectionsLoop reports connections to the agent for auditing.
func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentClient26) error {
func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentClient27) error {
for {
select {
case <-a.reportConnectionsUpdate:
@@ -887,7 +980,7 @@
// fetchServiceBannerLoop fetches the service banner on an interval. It will
// not be fetched immediately; the expectation is that it is primed elsewhere
// (and must be done before the session actually starts).
func (a *agent) fetchServiceBannerLoop(ctx context.Context, aAPI proto.DRPCAgentClient26) error {
func (a *agent) fetchServiceBannerLoop(ctx context.Context, aAPI proto.DRPCAgentClient27) error {
ticker := time.NewTicker(a.announcementBannersRefreshInterval)
defer ticker.Stop()
for {
@@ -923,7 +1016,7 @@
a.sessionToken.Store(&sessionToken)

// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs
aAPI, tAPI, err := a.client.ConnectRPC26(a.hardCtx)
aAPI, tAPI, err := a.client.ConnectRPC27(a.hardCtx)
if err != nil {
return err
}
@@ -940,7 +1033,7 @@
connMan := newAPIConnRoutineManager(a.gracefulCtx, a.hardCtx, a.logger, aAPI, tAPI)

connMan.startAgentAPI("init notification banners", gracefulShutdownBehaviorStop,
func(ctx context.Context, aAPI proto.DRPCAgentClient26) error {
func(ctx context.Context, aAPI proto.DRPCAgentClient27) error {
bannersProto, err := aAPI.GetAnnouncementBanners(ctx, &proto.GetAnnouncementBannersRequest{})
if err != nil {
return xerrors.Errorf("fetch service banner: %w", err)
@@ -957,7 +1050,7 @@
// sending logs gets gracefulShutdownBehaviorRemain because we want to send logs generated by
// shutdown scripts.
connMan.startAgentAPI("send logs", gracefulShutdownBehaviorRemain,
func(ctx context.Context, aAPI proto.DRPCAgentClient26) error {
func(ctx context.Context, aAPI proto.DRPCAgentClient27) error {
err := a.logSender.SendLoop(ctx, aAPI)
if xerrors.Is(err, agentsdk.ErrLogLimitExceeded) {
// we don't want this error to tear down the API connection and propagate to the
@@ -976,7 +1069,7 @@
connMan.startAgentAPI("report metadata", gracefulShutdownBehaviorStop, a.reportMetadata)

// resources monitor can cease as soon as we start gracefully shutting down.
connMan.startAgentAPI("resources monitor", gracefulShutdownBehaviorStop, func(ctx context.Context, aAPI proto.DRPCAgentClient26) error {
connMan.startAgentAPI("resources monitor", gracefulShutdownBehaviorStop, func(ctx context.Context, aAPI proto.DRPCAgentClient27) error {
logger := a.logger.Named("resources_monitor")
clk := quartz.NewReal()
config, err := aAPI.GetResourcesMonitoringConfiguration(ctx, &proto.GetResourcesMonitoringConfigurationRequest{})
@@ -1023,7 +1116,7 @@
connMan.startAgentAPI("handle manifest", gracefulShutdownBehaviorStop, a.handleManifest(manifestOK))

connMan.startAgentAPI("app health reporter", gracefulShutdownBehaviorStop,
func(ctx context.Context, aAPI proto.DRPCAgentClient26) error {
func(ctx context.Context, aAPI proto.DRPCAgentClient27) error {
if err := manifestOK.wait(ctx); err != nil {
return xerrors.Errorf("no manifest: %w", err)
}
@@ -1056,13 +1149,93 @@

connMan.startAgentAPI("fetch service banner loop", gracefulShutdownBehaviorStop, a.fetchServiceBannerLoop)

connMan.startAgentAPI("stats report loop", gracefulShutdownBehaviorStop, func(ctx context.Context, aAPI proto.DRPCAgentClient26) error {
connMan.startAgentAPI("stats report loop", gracefulShutdownBehaviorStop, func(ctx context.Context, aAPI proto.DRPCAgentClient27) error {
if err := networkOK.wait(ctx); err != nil {
return xerrors.Errorf("no network: %w", err)
}
return a.statsReporter.reportLoop(ctx, aAPI)
})

// Stream prebuild status to handle prebuilt workspace claims
connMan.startAgentAPI("stream prebuild status", gracefulShutdownBehaviorStop,
func(ctx context.Context, aAPI proto.DRPCAgentClient27) error {
if err := manifestOK.wait(ctx); err != nil {
return xerrors.Errorf("no manifest: %w", err)
}

a.logger.Debug(ctx, "starting prebuild status stream")

// Start streaming prebuild status
stream, err := aAPI.StreamPrebuildStatus(ctx, &proto.StreamPrebuildStatusRequest{})
if err != nil {
return xerrors.Errorf("start prebuild status stream: %w", err)
}

// Track previous status to detect transitions
var previousStatus proto.PrebuildStatus
isFirstStatus := true

// Process prebuild status updates
for {
response, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
a.logger.Info(ctx, "prebuild status stream ended")
return nil
}
return xerrors.Errorf("receive prebuild status: %w", err)
}

a.logger.Debug(ctx, "received prebuild status update",
slog.F("status", response.Status.String()),
slog.F("updated_at", response.UpdatedAt.AsTime()),
)

// Handle different prebuild statuses
switch response.Status {
case proto.PrebuildStatus_PREBUILD_CLAIM_STATUS_UNCLAIMED:
a.logger.Info(ctx, "prebuilt workspace is unclaimed, waiting for claim")
// Continue waiting for claim
case proto.PrebuildStatus_PREBUILD_CLAIM_STATUS_CLAIMED:
// Check if this is a transition from UNCLAIMED to CLAIMED
if !isFirstStatus && previousStatus == proto.PrebuildStatus_PREBUILD_CLAIM_STATUS_UNCLAIMED {
a.logger.Info(ctx, "prebuilt workspace has been claimed")

// Publish claim event for other routines to react to
a.eventPubsub.Publish(ctx, "workspace_claimed", map[string]interface{}{
"workspace_id": a.manifest.Load().WorkspaceID,
"agent_id": a.manifest.Load().AgentID,
"claimed_at": response.UpdatedAt.AsTime(),
})
}

// Prebuilt workspace has been claimed, we can continue
return nil
case proto.PrebuildStatus_PREBUILD_CLAIM_STATUS_NORMAL:
// Check if this is a transition from UNCLAIMED to NORMAL
if !isFirstStatus && previousStatus == proto.PrebuildStatus_PREBUILD_CLAIM_STATUS_UNCLAIMED {

[Check failure on line 1216 in agent/agent.go (GitHub Actions / lint): empty-lines: extra empty line at the end of a block (revive)]
a.logger.Info(ctx, "prebuilt workspace has been claimed")

// Publish claim event for other routines to react to
a.eventPubsub.Publish(ctx, "workspace_claimed", map[string]interface{}{
"workspace_id": a.manifest.Load().WorkspaceID,
"agent_id": a.manifest.Load().AgentID,
"claimed_at": response.UpdatedAt.AsTime(),
})

}
// This is a normal workspace, no need to wait for claim
return nil
default:
a.logger.Warn(ctx, "unknown prebuild status", slog.F("status", response.Status))
}

// Update previous status for next iteration
previousStatus = response.Status
isFirstStatus = false
}
})
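The CLAIMED and NORMAL branches above publish an identical event with an identical payload. A small helper on *agent, sketched below and not part of this diff, would keep the payload keys in one place; both branches could then simply call a.publishClaim(ctx, response.UpdatedAt.AsTime()).

```go
// publishClaim publishes the "workspace_claimed" event with the payload shape
// described in the SubscribeToClaimEvents doc comment. Sketch only.
func (a *agent) publishClaim(ctx context.Context, claimedAt time.Time) {
	manifest := a.manifest.Load()
	if manifest == nil {
		// The stream routine waits on manifestOK first, so this should not
		// happen; guard anyway.
		return
	}
	a.eventPubsub.Publish(ctx, "workspace_claimed", map[string]interface{}{
		"workspace_id": manifest.WorkspaceID,
		"agent_id":     manifest.AgentID,
		"claimed_at":   claimedAt,
	})
}
```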

err = connMan.wait()
if err != nil {
a.logger.Info(context.Background(), "connection manager errored", slog.Error(err))
@@ -1071,8 +1244,8 @@
}

// handleManifest returns a function that fetches and processes the manifest
func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context, aAPI proto.DRPCAgentClient26) error {
return func(ctx context.Context, aAPI proto.DRPCAgentClient26) error {
func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context, aAPI proto.DRPCAgentClient27) error {
return func(ctx context.Context, aAPI proto.DRPCAgentClient27) error {
var (
sentResult = false
err error
@@ -1267,8 +1440,8 @@

// createOrUpdateNetwork waits for the manifest to be set using manifestOK, then creates or updates
// the tailnet using the information in the manifest
func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(context.Context, proto.DRPCAgentClient26) error {
return func(ctx context.Context, aAPI proto.DRPCAgentClient26) (retErr error) {
func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(context.Context, proto.DRPCAgentClient27) error {
return func(ctx context.Context, aAPI proto.DRPCAgentClient27) (retErr error) {
if err := manifestOK.wait(ctx); err != nil {
return xerrors.Errorf("no manifest: %w", err)
}
@@ -2047,7 +2220,7 @@

type apiConnRoutineManager struct {
logger slog.Logger
aAPI proto.DRPCAgentClient26
aAPI proto.DRPCAgentClient27
tAPI tailnetproto.DRPCTailnetClient24
eg *errgroup.Group
stopCtx context.Context
@@ -2056,7 +2229,7 @@

func newAPIConnRoutineManager(
gracefulCtx, hardCtx context.Context, logger slog.Logger,
aAPI proto.DRPCAgentClient26, tAPI tailnetproto.DRPCTailnetClient24,
aAPI proto.DRPCAgentClient27, tAPI tailnetproto.DRPCTailnetClient24,
) *apiConnRoutineManager {
// routines that remain in operation during graceful shutdown use the remainCtx. They'll still
// exit if the errgroup hits an error, which usually means a problem with the conn.
@@ -2089,7 +2262,7 @@
// but for Tailnet.
func (a *apiConnRoutineManager) startAgentAPI(
name string, behavior gracefulShutdownBehavior,
f func(context.Context, proto.DRPCAgentClient26) error,
f func(context.Context, proto.DRPCAgentClient27) error,
) {
logger := a.logger.With(slog.F("name", name))
var ctx context.Context
4 changes: 2 additions & 2 deletions agent/agentcontainers/subagent_test.go
@@ -81,7 +81,7 @@ func TestSubAgentClient_CreateWithDisplayApps(t *testing.T) {

agentAPI := agenttest.NewClient(t, logger, uuid.New(), agentsdk.Manifest{}, statsCh, tailnet.NewCoordinator(logger))

agentClient, _, err := agentAPI.ConnectRPC26(ctx)
agentClient, _, err := agentAPI.ConnectRPC27(ctx)
require.NoError(t, err)

subAgentClient := agentcontainers.NewSubAgentClientFromAPI(logger, agentClient)
@@ -245,7 +245,7 @@ func TestSubAgentClient_CreateWithDisplayApps(t *testing.T) {

agentAPI := agenttest.NewClient(t, logger, uuid.New(), agentsdk.Manifest{}, statsCh, tailnet.NewCoordinator(logger))

agentClient, _, err := agentAPI.ConnectRPC26(ctx)
agentClient, _, err := agentAPI.ConnectRPC27(ctx)
require.NoError(t, err)

subAgentClient := agentcontainers.NewSubAgentClientFromAPI(logger, agentClient)