Skip to content

Commit 165e7dc

Browse files
committed
chore: add OAuth2 device flow test scripts
Change-Id: Ic232851727e683ab3d8b7ce970c505588da2f827 Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent 321563b commit 165e7dc

30 files changed

+1111
-679
lines changed

.claude/scripts/format.sh

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,30 +101,36 @@ fi
101101
# Get the file extension to determine the appropriate formatter
102102
file_ext="${file_path##*.}"
103103

104+
# Helper function to run formatter and handle errors
105+
run_formatter() {
106+
local target="$1"
107+
local file_type="$2"
108+
109+
if ! make FILE="$file_path" "$target"; then
110+
echo "Error: Failed to format $file_type file: $file_path" >&2
111+
exit 2
112+
fi
113+
echo "✓ Formatted $file_type file: $file_path"
114+
}
104115
# Change to the project root directory (where the Makefile is located)
105116
cd "$(dirname "$0")/../.."
106117

107118
# Call the appropriate Makefile target based on file extension
108119
case "$file_ext" in
109120
go)
110-
make fmt/go FILE="$file_path"
111-
echo "✓ Formatted Go file: $file_path"
121+
run_formatter "fmt/go" "Go"
112122
;;
113123
js | jsx | ts | tsx)
114-
make fmt/ts FILE="$file_path"
115-
echo "✓ Formatted TypeScript/JavaScript file: $file_path"
124+
run_formatter "fmt/ts" "TypeScript/JavaScript"
116125
;;
117126
tf | tfvars)
118-
make fmt/terraform FILE="$file_path"
119-
echo "✓ Formatted Terraform file: $file_path"
127+
run_formatter "fmt/terraform" "Terraform"
120128
;;
121129
sh)
122-
make fmt/shfmt FILE="$file_path"
123-
echo "✓ Formatted shell script: $file_path"
130+
run_formatter "fmt/shfmt" "shell script"
124131
;;
125132
md)
126-
make fmt/markdown FILE="$file_path"
127-
echo "✓ Formatted Markdown file: $file_path"
133+
run_formatter "fmt/markdown" "Markdown"
128134
;;
129135
*)
130136
echo "No formatter available for file extension: $file_ext"

coderd/coderd.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,7 @@ func New(options *Options) *API {
984984
r.Route("/device", func(r chi.Router) {
985985
r.Post("/", api.postOAuth2DeviceAuthorization()) // RFC 8628 compliant endpoint
986986
r.Route("/verify", func(r chi.Router) {
987-
r.Use(apiKeyMiddleware)
987+
r.Use(apiKeyMiddlewareRedirect)
988988
r.Get("/", api.getOAuth2DeviceVerification())
989989
r.Post("/", api.postOAuth2DeviceVerification())
990990
})

coderd/database/dbauthz/dbauthz.go

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,7 @@ var (
417417
rbac.ResourceProvisionerJobs.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreate},
418418
rbac.ResourceOauth2App.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
419419
rbac.ResourceOauth2AppSecret.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
420+
rbac.ResourceOauth2AppCodeToken.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
420421
}),
421422
Org: map[string][]rbac.Permission{},
422423
User: []rbac.Permission{},
@@ -1346,6 +1347,14 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
13461347
return q.db.CleanTailnetTunnels(ctx)
13471348
}
13481349

1350+
func (q *querier) ConsumeOAuth2ProviderAppCodeByPrefix(ctx context.Context, secretPrefix []byte) (database.OAuth2ProviderAppCode, error) {
1351+
return updateWithReturn(q.log, q.auth, q.db.GetOAuth2ProviderAppCodeByPrefix, q.db.ConsumeOAuth2ProviderAppCodeByPrefix)(ctx, secretPrefix)
1352+
}
1353+
1354+
func (q *querier) ConsumeOAuth2ProviderDeviceCodeByPrefix(ctx context.Context, deviceCodePrefix string) (database.OAuth2ProviderDeviceCode, error) {
1355+
return updateWithReturn(q.log, q.auth, q.db.GetOAuth2ProviderDeviceCodeByPrefix, q.db.ConsumeOAuth2ProviderDeviceCodeByPrefix)(ctx, deviceCodePrefix)
1356+
}
1357+
13491358
func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
13501359
// Shortcut if the user is an owner. The SQL filter is noticeable,
13511360
// and this is an easy win for owners. Which is the common case.
@@ -1560,16 +1569,6 @@ func (q *querier) DeleteOAuth2ProviderAppTokensByAppAndUserID(ctx context.Contex
15601569
return q.db.DeleteOAuth2ProviderAppTokensByAppAndUserID(ctx, arg)
15611570
}
15621571

1563-
func (q *querier) DeleteOldAuditLogConnectionEvents(ctx context.Context, threshold database.DeleteOldAuditLogConnectionEventsParams) error {
1564-
// `ResourceSystem` is deprecated, but it doesn't make sense to add
1565-
// `policy.ActionDelete` to `ResourceAuditLog`, since this is the one and
1566-
// only time we'll be deleting from the audit log.
1567-
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
1568-
return err
1569-
}
1570-
return q.db.DeleteOldAuditLogConnectionEvents(ctx, threshold)
1571-
}
1572-
15731572
func (q *querier) DeleteOAuth2ProviderDeviceCodeByID(ctx context.Context, id uuid.UUID) error {
15741573
// Fetch the device code first to check authorization
15751574
deviceCode, err := q.db.GetOAuth2ProviderDeviceCodeByID(ctx, id)
@@ -1583,6 +1582,16 @@ func (q *querier) DeleteOAuth2ProviderDeviceCodeByID(ctx context.Context, id uui
15831582
return q.db.DeleteOAuth2ProviderDeviceCodeByID(ctx, id)
15841583
}
15851584

1585+
func (q *querier) DeleteOldAuditLogConnectionEvents(ctx context.Context, threshold database.DeleteOldAuditLogConnectionEventsParams) error {
1586+
// `ResourceSystem` is deprecated, but it doesn't make sense to add
1587+
// `policy.ActionDelete` to `ResourceAuditLog`, since this is the one and
1588+
// only time we'll be deleting from the audit log.
1589+
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
1590+
return err
1591+
}
1592+
return q.db.DeleteOldAuditLogConnectionEvents(ctx, threshold)
1593+
}
1594+
15861595
func (q *querier) DeleteOldNotificationMessages(ctx context.Context) error {
15871596
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceNotificationMessage); err != nil {
15881597
return err
@@ -2359,8 +2368,8 @@ func (q *querier) GetOAuth2ProviderDeviceCodeByUserCode(ctx context.Context, use
23592368
}
23602369

23612370
func (q *querier) GetOAuth2ProviderDeviceCodesByClientID(ctx context.Context, clientID uuid.UUID) ([]database.OAuth2ProviderDeviceCode, error) {
2362-
// This requires access to read the OAuth2 app
2363-
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceOauth2App); err != nil {
2371+
// This requires access to read OAuth2 app code tokens
2372+
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceOauth2AppCodeToken); err != nil {
23642373
return []database.OAuth2ProviderDeviceCode{}, err
23652374
}
23662375
return q.db.GetOAuth2ProviderDeviceCodesByClientID(ctx, clientID)
@@ -3810,8 +3819,8 @@ func (q *querier) InsertOAuth2ProviderAppToken(ctx context.Context, arg database
38103819
}
38113820

38123821
func (q *querier) InsertOAuth2ProviderDeviceCode(ctx context.Context, arg database.InsertOAuth2ProviderDeviceCodeParams) (database.OAuth2ProviderDeviceCode, error) {
3813-
// Creating device codes requires OAuth2 app access
3814-
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceOauth2App); err != nil {
3822+
// Creating device codes requires OAuth2 app code token creation access
3823+
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceOauth2AppCodeToken); err != nil {
38153824
return database.OAuth2ProviderDeviceCode{}, err
38163825
}
38173826
return q.db.InsertOAuth2ProviderDeviceCode(ctx, arg)
@@ -4490,13 +4499,10 @@ func (q *querier) UpdateOAuth2ProviderAppSecretByID(ctx context.Context, arg dat
44904499
}
44914500

44924501
func (q *querier) UpdateOAuth2ProviderDeviceCodeAuthorization(ctx context.Context, arg database.UpdateOAuth2ProviderDeviceCodeAuthorizationParams) (database.OAuth2ProviderDeviceCode, error) {
4493-
// Verify the user is authenticated for device code authorization
4494-
_, ok := ActorFromContext(ctx)
4495-
if !ok {
4496-
return database.OAuth2ProviderDeviceCode{}, ErrNoActor
4502+
fetch := func(ctx context.Context, arg database.UpdateOAuth2ProviderDeviceCodeAuthorizationParams) (database.OAuth2ProviderDeviceCode, error) {
4503+
return q.db.GetOAuth2ProviderDeviceCodeByID(ctx, arg.ID)
44974504
}
4498-
4499-
return q.db.UpdateOAuth2ProviderDeviceCodeAuthorization(ctx, arg)
4505+
return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateOAuth2ProviderDeviceCodeAuthorization)(ctx, arg)
45004506
}
45014507

45024508
func (q *querier) UpdateOrganization(ctx context.Context, arg database.UpdateOrganizationParams) (database.Organization, error) {

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5532,6 +5532,19 @@ func (s *MethodTestSuite) TestOAuth2ProviderAppCodes() {
55325532
UserID: user.ID,
55335533
}).Asserts(rbac.ResourceOauth2AppCodeToken.WithOwner(user.ID.String()), policy.ActionDelete)
55345534
}))
5535+
s.Run("ConsumeOAuth2ProviderAppCodeByPrefix", s.Subtest(func(db database.Store, check *expects) {
5536+
user := dbgen.User(s.T(), db, database.User{})
5537+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5538+
// Use unique prefix to avoid test isolation issues
5539+
uniquePrefix := fmt.Sprintf("prefix-%s-%d", s.T().Name(), time.Now().UnixNano())
5540+
code := dbgen.OAuth2ProviderAppCode(s.T(), db, database.OAuth2ProviderAppCode{
5541+
SecretPrefix: []byte(uniquePrefix),
5542+
UserID: user.ID,
5543+
AppID: app.ID,
5544+
ExpiresAt: time.Now().Add(24 * time.Hour), // Extended expiry for test stability
5545+
})
5546+
check.Args(code.SecretPrefix).Asserts(code, policy.ActionUpdate).Returns(code)
5547+
}))
55355548
}
55365549

55375550
func (s *MethodTestSuite) TestOAuth2ProviderAppTokens() {
@@ -5607,6 +5620,115 @@ func (s *MethodTestSuite) TestOAuth2ProviderAppTokens() {
56075620
}))
56085621
}
56095622

5623+
func (s *MethodTestSuite) TestOAuth2ProviderDeviceCodes() {
5624+
s.Run("InsertOAuth2ProviderDeviceCode", s.Subtest(func(db database.Store, check *expects) {
5625+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5626+
check.Args(database.InsertOAuth2ProviderDeviceCodeParams{
5627+
ClientID: app.ID,
5628+
DeviceCodePrefix: "testpref",
5629+
DeviceCodeHash: []byte("hash"),
5630+
UserCode: "TEST1234",
5631+
VerificationUri: "http://example.com/device",
5632+
}).Asserts(rbac.ResourceOauth2AppCodeToken, policy.ActionCreate)
5633+
}))
5634+
s.Run("GetOAuth2ProviderDeviceCodeByID", s.Subtest(func(db database.Store, check *expects) {
5635+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5636+
deviceCode, err := db.InsertOAuth2ProviderDeviceCode(context.Background(), database.InsertOAuth2ProviderDeviceCodeParams{
5637+
ClientID: app.ID,
5638+
DeviceCodePrefix: "testpref",
5639+
UserCode: "TEST1234",
5640+
VerificationUri: "http://example.com/device",
5641+
})
5642+
require.NoError(s.T(), err)
5643+
check.Args(deviceCode.ID).Asserts(deviceCode, policy.ActionRead).Returns(deviceCode)
5644+
}))
5645+
s.Run("GetOAuth2ProviderDeviceCodeByPrefix", s.Subtest(func(db database.Store, check *expects) {
5646+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5647+
deviceCode, err := db.InsertOAuth2ProviderDeviceCode(context.Background(), database.InsertOAuth2ProviderDeviceCodeParams{
5648+
ClientID: app.ID,
5649+
DeviceCodePrefix: "testpref",
5650+
UserCode: "TEST1234",
5651+
VerificationUri: "http://example.com/device",
5652+
})
5653+
require.NoError(s.T(), err)
5654+
check.Args(deviceCode.DeviceCodePrefix).Asserts(deviceCode, policy.ActionRead).Returns(deviceCode)
5655+
}))
5656+
s.Run("GetOAuth2ProviderDeviceCodeByUserCode", s.Subtest(func(db database.Store, check *expects) {
5657+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5658+
deviceCode, err := db.InsertOAuth2ProviderDeviceCode(context.Background(), database.InsertOAuth2ProviderDeviceCodeParams{
5659+
ClientID: app.ID,
5660+
DeviceCodePrefix: "testpref",
5661+
UserCode: "TEST1234",
5662+
VerificationUri: "http://example.com/device",
5663+
})
5664+
require.NoError(s.T(), err)
5665+
check.Args(deviceCode.UserCode).Asserts(deviceCode, policy.ActionRead).Returns(deviceCode)
5666+
}))
5667+
s.Run("GetOAuth2ProviderDeviceCodesByClientID", s.Subtest(func(db database.Store, check *expects) {
5668+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5669+
deviceCode, err := db.InsertOAuth2ProviderDeviceCode(context.Background(), database.InsertOAuth2ProviderDeviceCodeParams{
5670+
ClientID: app.ID,
5671+
DeviceCodePrefix: "testpref",
5672+
UserCode: "TEST1234",
5673+
VerificationUri: "http://example.com/device",
5674+
})
5675+
require.NoError(s.T(), err)
5676+
check.Args(app.ID).Asserts(rbac.ResourceOauth2AppCodeToken, policy.ActionRead).Returns([]database.OAuth2ProviderDeviceCode{deviceCode})
5677+
}))
5678+
s.Run("ConsumeOAuth2ProviderDeviceCodeByPrefix", s.Subtest(func(db database.Store, check *expects) {
5679+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5680+
user := dbgen.User(s.T(), db, database.User{})
5681+
// Use unique identifiers to avoid test isolation issues
5682+
// Device code prefix must be exactly 8 characters
5683+
uniquePrefix := fmt.Sprintf("t%07d", time.Now().UnixNano()%10000000)
5684+
uniqueUserCode := fmt.Sprintf("USER%04d", time.Now().UnixNano()%10000)
5685+
// Create device code using dbgen (now available!)
5686+
deviceCode := dbgen.OAuth2ProviderDeviceCode(s.T(), db, database.OAuth2ProviderDeviceCode{
5687+
DeviceCodePrefix: uniquePrefix,
5688+
UserCode: uniqueUserCode,
5689+
ClientID: app.ID,
5690+
ExpiresAt: time.Now().Add(24 * time.Hour), // Extended expiry for test stability
5691+
})
5692+
// Authorize the device code so it can be consumed
5693+
deviceCode, err := db.UpdateOAuth2ProviderDeviceCodeAuthorization(s.T().Context(), database.UpdateOAuth2ProviderDeviceCodeAuthorizationParams{
5694+
ID: deviceCode.ID,
5695+
UserID: uuid.NullUUID{UUID: user.ID, Valid: true},
5696+
Status: database.OAuth2DeviceStatusAuthorized,
5697+
})
5698+
require.NoError(s.T(), err)
5699+
require.Equal(s.T(), database.OAuth2DeviceStatusAuthorized, deviceCode.Status)
5700+
check.Args(uniquePrefix).Asserts(deviceCode, policy.ActionUpdate).Returns(deviceCode)
5701+
}))
5702+
s.Run("UpdateOAuth2ProviderDeviceCodeAuthorization", s.Subtest(func(db database.Store, check *expects) {
5703+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5704+
user := dbgen.User(s.T(), db, database.User{})
5705+
// Create device code using dbgen
5706+
deviceCode := dbgen.OAuth2ProviderDeviceCode(s.T(), db, database.OAuth2ProviderDeviceCode{
5707+
ClientID: app.ID,
5708+
})
5709+
require.Equal(s.T(), database.OAuth2DeviceStatusPending, deviceCode.Status)
5710+
check.Args(database.UpdateOAuth2ProviderDeviceCodeAuthorizationParams{
5711+
ID: deviceCode.ID,
5712+
UserID: uuid.NullUUID{UUID: user.ID, Valid: true},
5713+
Status: database.OAuth2DeviceStatusAuthorized,
5714+
}).Asserts(deviceCode, policy.ActionUpdate)
5715+
}))
5716+
s.Run("DeleteOAuth2ProviderDeviceCodeByID", s.Subtest(func(db database.Store, check *expects) {
5717+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5718+
deviceCode, err := db.InsertOAuth2ProviderDeviceCode(context.Background(), database.InsertOAuth2ProviderDeviceCodeParams{
5719+
ClientID: app.ID,
5720+
DeviceCodePrefix: "testpref",
5721+
UserCode: "TEST1234",
5722+
VerificationUri: "http://example.com/device",
5723+
})
5724+
require.NoError(s.T(), err)
5725+
check.Args(deviceCode.ID).Asserts(deviceCode, policy.ActionDelete)
5726+
}))
5727+
s.Run("DeleteExpiredOAuth2ProviderDeviceCodes", s.Subtest(func(db database.Store, check *expects) {
5728+
check.Args().Asserts(rbac.ResourceSystem, policy.ActionDelete)
5729+
}))
5730+
}
5731+
56105732
func (s *MethodTestSuite) TestResourcesMonitor() {
56115733
createAgent := func(t *testing.T, db database.Store) (database.WorkspaceAgent, database.WorkspaceTable) {
56125734
t.Helper()

coderd/database/dbgen/dbgen.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,7 +1246,7 @@ func OAuth2ProviderAppCode(t testing.TB, db database.Store, seed database.OAuth2
12461246
code, err := db.InsertOAuth2ProviderAppCode(genCtx, database.InsertOAuth2ProviderAppCodeParams{
12471247
ID: takeFirst(seed.ID, uuid.New()),
12481248
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
1249-
ExpiresAt: takeFirst(seed.CreatedAt, dbtime.Now()),
1249+
ExpiresAt: takeFirst(seed.ExpiresAt, dbtime.Now().Add(24*time.Hour)),
12501250
SecretPrefix: takeFirstSlice(seed.SecretPrefix, []byte("prefix")),
12511251
HashedSecret: takeFirstSlice(seed.HashedSecret, []byte("hashed-secret")),
12521252
AppID: takeFirst(seed.AppID, uuid.New()),
@@ -1263,7 +1263,7 @@ func OAuth2ProviderAppToken(t testing.TB, db database.Store, seed database.OAuth
12631263
token, err := db.InsertOAuth2ProviderAppToken(genCtx, database.InsertOAuth2ProviderAppTokenParams{
12641264
ID: takeFirst(seed.ID, uuid.New()),
12651265
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
1266-
ExpiresAt: takeFirst(seed.CreatedAt, dbtime.Now()),
1266+
ExpiresAt: takeFirst(seed.ExpiresAt, dbtime.Now().Add(24*time.Hour)),
12671267
HashPrefix: takeFirstSlice(seed.HashPrefix, []byte("prefix")),
12681268
RefreshHash: takeFirstSlice(seed.RefreshHash, []byte("hashed-secret")),
12691269
AppSecretID: takeFirst(seed.AppSecretID, uuid.New()),
@@ -1275,6 +1275,25 @@ func OAuth2ProviderAppToken(t testing.TB, db database.Store, seed database.OAuth
12751275
return token
12761276
}
12771277

1278+
func OAuth2ProviderDeviceCode(t testing.TB, db database.Store, seed database.OAuth2ProviderDeviceCode) database.OAuth2ProviderDeviceCode {
1279+
deviceCode, err := db.InsertOAuth2ProviderDeviceCode(genCtx, database.InsertOAuth2ProviderDeviceCodeParams{
1280+
ID: takeFirst(seed.ID, uuid.New()),
1281+
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
1282+
ExpiresAt: takeFirst(seed.ExpiresAt, dbtime.Now().Add(24*time.Hour)),
1283+
DeviceCodeHash: takeFirstSlice(seed.DeviceCodeHash, []byte("device-hash")),
1284+
DeviceCodePrefix: takeFirst(seed.DeviceCodePrefix, testutil.GetRandomName(t)[:8]),
1285+
UserCode: takeFirst(seed.UserCode, testutil.GetRandomName(t)),
1286+
ClientID: takeFirst(seed.ClientID, uuid.New()),
1287+
VerificationUri: takeFirst(seed.VerificationUri, "https://example.com/device"),
1288+
VerificationUriComplete: seed.VerificationUriComplete,
1289+
Scope: seed.Scope,
1290+
ResourceUri: seed.ResourceUri,
1291+
PollingInterval: takeFirst(seed.PollingInterval, 5),
1292+
})
1293+
require.NoError(t, err, "insert oauth2 device code")
1294+
return deviceCode
1295+
}
1296+
12781297
func WorkspaceAgentMemoryResourceMonitor(t testing.TB, db database.Store, seed database.WorkspaceAgentMemoryResourceMonitor) database.WorkspaceAgentMemoryResourceMonitor {
12791298
monitor, err := db.InsertMemoryResourceMonitor(genCtx, database.InsertMemoryResourceMonitorParams{
12801299
AgentID: takeFirst(seed.AgentID, uuid.New()),

coderd/database/dbmetrics/querymetrics.go

Lines changed: 20 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)