Skip to content

Commit 43b0bb7

Browse files
feat(site): use websocket connection for devcontainer updates (#18808)
Instead of polling every 10 seconds, we instead use a WebSocket connection for more timely updates.
1 parent 7cf3263 commit 43b0bb7

File tree

15 files changed

+1079
-23
lines changed

15 files changed

+1079
-23
lines changed

agent/agentcontainers/api.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ package agentcontainers
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
67
"fmt"
8+
"maps"
79
"net/http"
810
"os"
911
"path"
@@ -30,6 +32,7 @@ import (
3032
"github.com/coder/coder/v2/codersdk/agentsdk"
3133
"github.com/coder/coder/v2/provisioner"
3234
"github.com/coder/quartz"
35+
"github.com/coder/websocket"
3336
)
3437

3538
const (
@@ -74,6 +77,7 @@ type API struct {
7477

7578
mu sync.RWMutex // Protects the following fields.
7679
initDone chan struct{} // Closed by Init.
80+
updateChans []chan struct{}
7781
closed bool
7882
containers codersdk.WorkspaceAgentListContainersResponse // Output from the last list operation.
7983
containersErr error // Error from the last list operation.
@@ -535,6 +539,7 @@ func (api *API) Routes() http.Handler {
535539
r.Use(ensureInitDoneMW)
536540

537541
r.Get("/", api.handleList)
542+
r.Get("/watch", api.watchContainers)
538543
// TODO(mafredri): Simplify this route as the previous /devcontainers
539544
// /-route was dropped. We can drop the /devcontainers prefix here too.
540545
r.Route("/devcontainers/{devcontainer}", func(r chi.Router) {
@@ -544,6 +549,88 @@ func (api *API) Routes() http.Handler {
544549
return r
545550
}
546551

552+
func (api *API) broadcastUpdatesLocked() {
553+
// Broadcast state changes to WebSocket listeners.
554+
for _, ch := range api.updateChans {
555+
select {
556+
case ch <- struct{}{}:
557+
default:
558+
}
559+
}
560+
}
561+
562+
func (api *API) watchContainers(rw http.ResponseWriter, r *http.Request) {
563+
ctx := r.Context()
564+
565+
conn, err := websocket.Accept(rw, r, nil)
566+
if err != nil {
567+
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
568+
Message: "Failed to upgrade connection to websocket.",
569+
Detail: err.Error(),
570+
})
571+
return
572+
}
573+
574+
// Here we close the websocket for reading, so that the websocket library will handle pings and
575+
// close frames.
576+
_ = conn.CloseRead(context.Background())
577+
578+
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
579+
defer wsNetConn.Close()
580+
581+
go httpapi.Heartbeat(ctx, conn)
582+
583+
updateCh := make(chan struct{}, 1)
584+
585+
api.mu.Lock()
586+
api.updateChans = append(api.updateChans, updateCh)
587+
api.mu.Unlock()
588+
589+
defer func() {
590+
api.mu.Lock()
591+
api.updateChans = slices.DeleteFunc(api.updateChans, func(ch chan struct{}) bool {
592+
return ch == updateCh
593+
})
594+
close(updateCh)
595+
api.mu.Unlock()
596+
}()
597+
598+
encoder := json.NewEncoder(wsNetConn)
599+
600+
ct, err := api.getContainers()
601+
if err != nil {
602+
api.logger.Error(ctx, "unable to get containers", slog.Error(err))
603+
return
604+
}
605+
606+
if err := encoder.Encode(ct); err != nil {
607+
api.logger.Error(ctx, "encode container list", slog.Error(err))
608+
return
609+
}
610+
611+
for {
612+
select {
613+
case <-api.ctx.Done():
614+
return
615+
616+
case <-ctx.Done():
617+
return
618+
619+
case <-updateCh:
620+
ct, err := api.getContainers()
621+
if err != nil {
622+
api.logger.Error(ctx, "unable to get containers", slog.Error(err))
623+
continue
624+
}
625+
626+
if err := encoder.Encode(ct); err != nil {
627+
api.logger.Error(ctx, "encode container list", slog.Error(err))
628+
return
629+
}
630+
}
631+
}
632+
}
633+
547634
// handleList handles the HTTP request to list containers.
548635
func (api *API) handleList(rw http.ResponseWriter, r *http.Request) {
549636
ct, err := api.getContainers()
@@ -583,8 +670,26 @@ func (api *API) updateContainers(ctx context.Context) error {
583670
api.mu.Lock()
584671
defer api.mu.Unlock()
585672

673+
var previouslyKnownDevcontainers map[string]codersdk.WorkspaceAgentDevcontainer
674+
if len(api.updateChans) > 0 {
675+
previouslyKnownDevcontainers = maps.Clone(api.knownDevcontainers)
676+
}
677+
586678
api.processUpdatedContainersLocked(ctx, updated)
587679

680+
if len(api.updateChans) > 0 {
681+
statesAreEqual := maps.EqualFunc(
682+
previouslyKnownDevcontainers,
683+
api.knownDevcontainers,
684+
func(dc1, dc2 codersdk.WorkspaceAgentDevcontainer) bool {
685+
return dc1.Equals(dc2)
686+
})
687+
688+
if !statesAreEqual {
689+
api.broadcastUpdatesLocked()
690+
}
691+
}
692+
588693
api.logger.Debug(ctx, "containers updated successfully", slog.F("container_count", len(api.containers.Containers)), slog.F("warning_count", len(api.containers.Warnings)), slog.F("devcontainer_count", len(api.knownDevcontainers)))
589694

590695
return nil
@@ -955,6 +1060,8 @@ func (api *API) handleDevcontainerRecreate(w http.ResponseWriter, r *http.Reques
9551060
dc.Container = nil
9561061
dc.Error = ""
9571062
api.knownDevcontainers[dc.WorkspaceFolder] = dc
1063+
api.broadcastUpdatesLocked()
1064+
9581065
go func() {
9591066
_ = api.CreateDevcontainer(dc.WorkspaceFolder, dc.ConfigPath, WithRemoveExistingContainer())
9601067
}()
@@ -1070,6 +1177,7 @@ func (api *API) CreateDevcontainer(workspaceFolder, configPath string, opts ...D
10701177
dc.Error = ""
10711178
api.recreateSuccessTimes[dc.WorkspaceFolder] = api.clock.Now("agentcontainers", "recreate", "successTimes")
10721179
api.knownDevcontainers[dc.WorkspaceFolder] = dc
1180+
api.broadcastUpdatesLocked()
10731181
api.mu.Unlock()
10741182

10751183
// Ensure an immediate refresh to accurately reflect the

agent/agentcontainers/api_test.go

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636
"github.com/coder/coder/v2/pty"
3737
"github.com/coder/coder/v2/testutil"
3838
"github.com/coder/quartz"
39+
"github.com/coder/websocket"
3940
)
4041

4142
// fakeContainerCLI implements the agentcontainers.ContainerCLI interface for
@@ -441,6 +442,178 @@ func TestAPI(t *testing.T) {
441442
logbuf.Reset()
442443
})
443444

445+
t.Run("Watch", func(t *testing.T) {
446+
t.Parallel()
447+
448+
fakeContainer1 := fakeContainer(t, func(c *codersdk.WorkspaceAgentContainer) {
449+
c.ID = "container1"
450+
c.FriendlyName = "devcontainer1"
451+
c.Image = "busybox:latest"
452+
c.Labels = map[string]string{
453+
agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project1",
454+
agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project1/.devcontainer/devcontainer.json",
455+
}
456+
})
457+
458+
fakeContainer2 := fakeContainer(t, func(c *codersdk.WorkspaceAgentContainer) {
459+
c.ID = "container2"
460+
c.FriendlyName = "devcontainer2"
461+
c.Image = "ubuntu:latest"
462+
c.Labels = map[string]string{
463+
agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project2",
464+
agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project2/.devcontainer/devcontainer.json",
465+
}
466+
})
467+
468+
stages := []struct {
469+
containers []codersdk.WorkspaceAgentContainer
470+
expected codersdk.WorkspaceAgentListContainersResponse
471+
}{
472+
{
473+
containers: []codersdk.WorkspaceAgentContainer{fakeContainer1},
474+
expected: codersdk.WorkspaceAgentListContainersResponse{
475+
Containers: []codersdk.WorkspaceAgentContainer{fakeContainer1},
476+
Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
477+
{
478+
Name: "project1",
479+
WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel],
480+
ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel],
481+
Status: "running",
482+
Container: &fakeContainer1,
483+
},
484+
},
485+
},
486+
},
487+
{
488+
containers: []codersdk.WorkspaceAgentContainer{fakeContainer1, fakeContainer2},
489+
expected: codersdk.WorkspaceAgentListContainersResponse{
490+
Containers: []codersdk.WorkspaceAgentContainer{fakeContainer1, fakeContainer2},
491+
Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
492+
{
493+
Name: "project1",
494+
WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel],
495+
ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel],
496+
Status: "running",
497+
Container: &fakeContainer1,
498+
},
499+
{
500+
Name: "project2",
501+
WorkspaceFolder: fakeContainer2.Labels[agentcontainers.DevcontainerLocalFolderLabel],
502+
ConfigPath: fakeContainer2.Labels[agentcontainers.DevcontainerConfigFileLabel],
503+
Status: "running",
504+
Container: &fakeContainer2,
505+
},
506+
},
507+
},
508+
},
509+
{
510+
containers: []codersdk.WorkspaceAgentContainer{fakeContainer2},
511+
expected: codersdk.WorkspaceAgentListContainersResponse{
512+
Containers: []codersdk.WorkspaceAgentContainer{fakeContainer2},
513+
Devcontainers: []codersdk.WorkspaceAgentDevcontainer{
514+
{
515+
Name: "",
516+
WorkspaceFolder: fakeContainer1.Labels[agentcontainers.DevcontainerLocalFolderLabel],
517+
ConfigPath: fakeContainer1.Labels[agentcontainers.DevcontainerConfigFileLabel],
518+
Status: "stopped",
519+
Container: nil,
520+
},
521+
{
522+
Name: "project2",
523+
WorkspaceFolder: fakeContainer2.Labels[agentcontainers.DevcontainerLocalFolderLabel],
524+
ConfigPath: fakeContainer2.Labels[agentcontainers.DevcontainerConfigFileLabel],
525+
Status: "running",
526+
Container: &fakeContainer2,
527+
},
528+
},
529+
},
530+
},
531+
}
532+
533+
var (
534+
ctx = testutil.Context(t, testutil.WaitShort)
535+
mClock = quartz.NewMock(t)
536+
updaterTickerTrap = mClock.Trap().TickerFunc("updaterLoop")
537+
mCtrl = gomock.NewController(t)
538+
mLister = acmock.NewMockContainerCLI(mCtrl)
539+
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
540+
)
541+
542+
// Set up initial state for immediate send on connection
543+
mLister.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: stages[0].containers}, nil)
544+
mLister.EXPECT().DetectArchitecture(gomock.Any(), gomock.Any()).Return("<none>", nil).AnyTimes()
545+
546+
api := agentcontainers.NewAPI(logger,
547+
agentcontainers.WithClock(mClock),
548+
agentcontainers.WithContainerCLI(mLister),
549+
agentcontainers.WithWatcher(watcher.NewNoop()),
550+
)
551+
api.Start()
552+
defer api.Close()
553+
554+
srv := httptest.NewServer(api.Routes())
555+
defer srv.Close()
556+
557+
updaterTickerTrap.MustWait(ctx).MustRelease(ctx)
558+
defer updaterTickerTrap.Close()
559+
560+
client, res, err := websocket.Dial(ctx, srv.URL+"/watch", nil)
561+
require.NoError(t, err)
562+
if res != nil && res.Body != nil {
563+
defer res.Body.Close()
564+
}
565+
566+
// Read initial state sent immediately on connection
567+
mt, msg, err := client.Read(ctx)
568+
require.NoError(t, err)
569+
require.Equal(t, websocket.MessageText, mt)
570+
571+
var got codersdk.WorkspaceAgentListContainersResponse
572+
err = json.Unmarshal(msg, &got)
573+
require.NoError(t, err)
574+
575+
require.Equal(t, stages[0].expected.Containers, got.Containers)
576+
require.Len(t, got.Devcontainers, len(stages[0].expected.Devcontainers))
577+
for j, expectedDev := range stages[0].expected.Devcontainers {
578+
gotDev := got.Devcontainers[j]
579+
require.Equal(t, expectedDev.Name, gotDev.Name)
580+
require.Equal(t, expectedDev.WorkspaceFolder, gotDev.WorkspaceFolder)
581+
require.Equal(t, expectedDev.ConfigPath, gotDev.ConfigPath)
582+
require.Equal(t, expectedDev.Status, gotDev.Status)
583+
require.Equal(t, expectedDev.Container, gotDev.Container)
584+
}
585+
586+
// Process remaining stages through updater loop
587+
for i, stage := range stages[1:] {
588+
mLister.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: stage.containers}, nil)
589+
590+
// Given: We allow the update loop to progress
591+
_, aw := mClock.AdvanceNext()
592+
aw.MustWait(ctx)
593+
594+
// When: We attempt to read a message from the socket.
595+
mt, msg, err := client.Read(ctx)
596+
require.NoError(t, err)
597+
require.Equal(t, websocket.MessageText, mt)
598+
599+
// Then: We expect the receieved message matches the expected response.
600+
var got codersdk.WorkspaceAgentListContainersResponse
601+
err = json.Unmarshal(msg, &got)
602+
require.NoError(t, err)
603+
604+
require.Equal(t, stages[i+1].expected.Containers, got.Containers)
605+
require.Len(t, got.Devcontainers, len(stages[i+1].expected.Devcontainers))
606+
for j, expectedDev := range stages[i+1].expected.Devcontainers {
607+
gotDev := got.Devcontainers[j]
608+
require.Equal(t, expectedDev.Name, gotDev.Name)
609+
require.Equal(t, expectedDev.WorkspaceFolder, gotDev.WorkspaceFolder)
610+
require.Equal(t, expectedDev.ConfigPath, gotDev.ConfigPath)
611+
require.Equal(t, expectedDev.Status, gotDev.Status)
612+
require.Equal(t, expectedDev.Container, gotDev.Container)
613+
}
614+
}
615+
})
616+
444617
// List tests the API.getContainers method using a mock
445618
// implementation. It specifically tests caching behavior.
446619
t.Run("List", func(t *testing.T) {

coderd/apidoc/docs.go

Lines changed: 35 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)