Skip to content

Commit 386d77d

Browse files
committed
chore: add OAuth2 device flow test scripts
Change-Id: Ic232851727e683ab3d8b7ce970c505588da2f827 Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent 18c8ebf commit 386d77d

File tree

24 files changed

+794
-104
lines changed

24 files changed

+794
-104
lines changed

.claude/scripts/format.sh

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,30 +101,36 @@ fi
101101
# Get the file extension to determine the appropriate formatter
102102
file_ext="${file_path##*.}"
103103

104+
# Helper function to run formatter and handle errors
105+
run_formatter() {
106+
local target="$1"
107+
local file_type="$2"
108+
109+
if ! make FILE="$file_path" "$target"; then
110+
echo "Error: Failed to format $file_type file: $file_path" >&2
111+
exit 2
112+
fi
113+
echo "✓ Formatted $file_type file: $file_path"
114+
}
104115
# Change to the project root directory (where the Makefile is located)
105116
cd "$(dirname "$0")/../.."
106117

107118
# Call the appropriate Makefile target based on file extension
108119
case "$file_ext" in
109120
go)
110-
make fmt/go FILE="$file_path"
111-
echo "✓ Formatted Go file: $file_path"
121+
run_formatter "fmt/go" "Go"
112122
;;
113123
js | jsx | ts | tsx)
114-
make fmt/ts FILE="$file_path"
115-
echo "✓ Formatted TypeScript/JavaScript file: $file_path"
124+
run_formatter "fmt/ts" "TypeScript/JavaScript"
116125
;;
117126
tf | tfvars)
118-
make fmt/terraform FILE="$file_path"
119-
echo "✓ Formatted Terraform file: $file_path"
127+
run_formatter "fmt/terraform" "Terraform"
120128
;;
121129
sh)
122-
make fmt/shfmt FILE="$file_path"
123-
echo "✓ Formatted shell script: $file_path"
130+
run_formatter "fmt/shfmt" "shell script"
124131
;;
125132
md)
126-
make fmt/markdown FILE="$file_path"
127-
echo "✓ Formatted Markdown file: $file_path"
133+
run_formatter "fmt/markdown" "Markdown"
128134
;;
129135
*)
130136
echo "No formatter available for file extension: $file_ext"

coderd/coderd.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ func New(options *Options) *API {
974974
r.Route("/device", func(r chi.Router) {
975975
r.Post("/", api.postOAuth2DeviceAuthorization()) // RFC 8628 compliant endpoint
976976
r.Route("/verify", func(r chi.Router) {
977-
r.Use(apiKeyMiddleware)
977+
r.Use(apiKeyMiddlewareRedirect)
978978
r.Get("/", api.getOAuth2DeviceVerification())
979979
r.Post("/", api.postOAuth2DeviceVerification())
980980
})

coderd/database/dbauthz/dbauthz.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ var (
399399
rbac.ResourceProvisionerJobs.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreate},
400400
rbac.ResourceOauth2App.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
401401
rbac.ResourceOauth2AppSecret.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
402+
rbac.ResourceOauth2AppCodeToken.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
402403
}),
403404
Org: map[string][]rbac.Permission{},
404405
User: []rbac.Permission{},
@@ -1324,6 +1325,14 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
13241325
return q.db.CleanTailnetTunnels(ctx)
13251326
}
13261327

1328+
func (q *querier) ConsumeOAuth2ProviderAppCodeByPrefix(ctx context.Context, secretPrefix []byte) (database.OAuth2ProviderAppCode, error) {
1329+
return updateWithReturn(q.log, q.auth, q.db.GetOAuth2ProviderAppCodeByPrefix, q.db.ConsumeOAuth2ProviderAppCodeByPrefix)(ctx, secretPrefix)
1330+
}
1331+
1332+
func (q *querier) ConsumeOAuth2ProviderDeviceCodeByPrefix(ctx context.Context, deviceCodePrefix string) (database.OAuth2ProviderDeviceCode, error) {
1333+
return updateWithReturn(q.log, q.auth, q.db.GetOAuth2ProviderDeviceCodeByPrefix, q.db.ConsumeOAuth2ProviderDeviceCodeByPrefix)(ctx, deviceCodePrefix)
1334+
}
1335+
13271336
func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
13281337
// Shortcut if the user is an owner. The SQL filter is noticeable,
13291338
// and this is an easy win for owners. Which is the common case.
@@ -2301,8 +2310,8 @@ func (q *querier) GetOAuth2ProviderDeviceCodeByUserCode(ctx context.Context, use
23012310
}
23022311

23032312
func (q *querier) GetOAuth2ProviderDeviceCodesByClientID(ctx context.Context, clientID uuid.UUID) ([]database.OAuth2ProviderDeviceCode, error) {
2304-
// This requires access to read the OAuth2 app
2305-
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceOauth2App); err != nil {
2313+
// This requires access to read OAuth2 app code tokens
2314+
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceOauth2AppCodeToken); err != nil {
23062315
return []database.OAuth2ProviderDeviceCode{}, err
23072316
}
23082317
return q.db.GetOAuth2ProviderDeviceCodesByClientID(ctx, clientID)
@@ -3752,8 +3761,8 @@ func (q *querier) InsertOAuth2ProviderAppToken(ctx context.Context, arg database
37523761
}
37533762

37543763
func (q *querier) InsertOAuth2ProviderDeviceCode(ctx context.Context, arg database.InsertOAuth2ProviderDeviceCodeParams) (database.OAuth2ProviderDeviceCode, error) {
3755-
// Creating device codes requires OAuth2 app access
3756-
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceOauth2App); err != nil {
3764+
// Creating device codes requires OAuth2 app code token creation access
3765+
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceOauth2AppCodeToken); err != nil {
37573766
return database.OAuth2ProviderDeviceCode{}, err
37583767
}
37593768
return q.db.InsertOAuth2ProviderDeviceCode(ctx, arg)
@@ -4432,13 +4441,10 @@ func (q *querier) UpdateOAuth2ProviderAppSecretByID(ctx context.Context, arg dat
44324441
}
44334442

44344443
func (q *querier) UpdateOAuth2ProviderDeviceCodeAuthorization(ctx context.Context, arg database.UpdateOAuth2ProviderDeviceCodeAuthorizationParams) (database.OAuth2ProviderDeviceCode, error) {
4435-
// Verify the user is authenticated for device code authorization
4436-
_, ok := ActorFromContext(ctx)
4437-
if !ok {
4438-
return database.OAuth2ProviderDeviceCode{}, ErrNoActor
4444+
fetch := func(ctx context.Context, arg database.UpdateOAuth2ProviderDeviceCodeAuthorizationParams) (database.OAuth2ProviderDeviceCode, error) {
4445+
return q.db.GetOAuth2ProviderDeviceCodeByID(ctx, arg.ID)
44394446
}
4440-
4441-
return q.db.UpdateOAuth2ProviderDeviceCodeAuthorization(ctx, arg)
4447+
return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateOAuth2ProviderDeviceCodeAuthorization)(ctx, arg)
44424448
}
44434449

44444450
func (q *querier) UpdateOrganization(ctx context.Context, arg database.UpdateOrganizationParams) (database.Organization, error) {

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5423,6 +5423,19 @@ func (s *MethodTestSuite) TestOAuth2ProviderAppCodes() {
54235423
UserID: user.ID,
54245424
}).Asserts(rbac.ResourceOauth2AppCodeToken.WithOwner(user.ID.String()), policy.ActionDelete)
54255425
}))
5426+
s.Run("ConsumeOAuth2ProviderAppCodeByPrefix", s.Subtest(func(db database.Store, check *expects) {
5427+
user := dbgen.User(s.T(), db, database.User{})
5428+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5429+
// Use unique prefix to avoid test isolation issues
5430+
uniquePrefix := fmt.Sprintf("prefix-%s-%d", s.T().Name(), time.Now().UnixNano())
5431+
code := dbgen.OAuth2ProviderAppCode(s.T(), db, database.OAuth2ProviderAppCode{
5432+
SecretPrefix: []byte(uniquePrefix),
5433+
UserID: user.ID,
5434+
AppID: app.ID,
5435+
ExpiresAt: time.Now().Add(24 * time.Hour), // Extended expiry for test stability
5436+
})
5437+
check.Args(code.SecretPrefix).Asserts(code, policy.ActionUpdate).Returns(code)
5438+
}))
54265439
}
54275440

54285441
func (s *MethodTestSuite) TestOAuth2ProviderAppTokens() {
@@ -5498,6 +5511,115 @@ func (s *MethodTestSuite) TestOAuth2ProviderAppTokens() {
54985511
}))
54995512
}
55005513

5514+
func (s *MethodTestSuite) TestOAuth2ProviderDeviceCodes() {
5515+
s.Run("InsertOAuth2ProviderDeviceCode", s.Subtest(func(db database.Store, check *expects) {
5516+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5517+
check.Args(database.InsertOAuth2ProviderDeviceCodeParams{
5518+
ClientID: app.ID,
5519+
DeviceCodePrefix: "testpref",
5520+
DeviceCodeHash: []byte("hash"),
5521+
UserCode: "TEST1234",
5522+
VerificationUri: "http://example.com/device",
5523+
}).Asserts(rbac.ResourceOauth2AppCodeToken, policy.ActionCreate)
5524+
}))
5525+
s.Run("GetOAuth2ProviderDeviceCodeByID", s.Subtest(func(db database.Store, check *expects) {
5526+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5527+
deviceCode, err := db.InsertOAuth2ProviderDeviceCode(context.Background(), database.InsertOAuth2ProviderDeviceCodeParams{
5528+
ClientID: app.ID,
5529+
DeviceCodePrefix: "testpref",
5530+
UserCode: "TEST1234",
5531+
VerificationUri: "http://example.com/device",
5532+
})
5533+
require.NoError(s.T(), err)
5534+
check.Args(deviceCode.ID).Asserts(deviceCode, policy.ActionRead).Returns(deviceCode)
5535+
}))
5536+
s.Run("GetOAuth2ProviderDeviceCodeByPrefix", s.Subtest(func(db database.Store, check *expects) {
5537+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5538+
deviceCode, err := db.InsertOAuth2ProviderDeviceCode(context.Background(), database.InsertOAuth2ProviderDeviceCodeParams{
5539+
ClientID: app.ID,
5540+
DeviceCodePrefix: "testpref",
5541+
UserCode: "TEST1234",
5542+
VerificationUri: "http://example.com/device",
5543+
})
5544+
require.NoError(s.T(), err)
5545+
check.Args(deviceCode.DeviceCodePrefix).Asserts(deviceCode, policy.ActionRead).Returns(deviceCode)
5546+
}))
5547+
s.Run("GetOAuth2ProviderDeviceCodeByUserCode", s.Subtest(func(db database.Store, check *expects) {
5548+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5549+
deviceCode, err := db.InsertOAuth2ProviderDeviceCode(context.Background(), database.InsertOAuth2ProviderDeviceCodeParams{
5550+
ClientID: app.ID,
5551+
DeviceCodePrefix: "testpref",
5552+
UserCode: "TEST1234",
5553+
VerificationUri: "http://example.com/device",
5554+
})
5555+
require.NoError(s.T(), err)
5556+
check.Args(deviceCode.UserCode).Asserts(deviceCode, policy.ActionRead).Returns(deviceCode)
5557+
}))
5558+
s.Run("GetOAuth2ProviderDeviceCodesByClientID", s.Subtest(func(db database.Store, check *expects) {
5559+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5560+
deviceCode, err := db.InsertOAuth2ProviderDeviceCode(context.Background(), database.InsertOAuth2ProviderDeviceCodeParams{
5561+
ClientID: app.ID,
5562+
DeviceCodePrefix: "testpref",
5563+
UserCode: "TEST1234",
5564+
VerificationUri: "http://example.com/device",
5565+
})
5566+
require.NoError(s.T(), err)
5567+
check.Args(app.ID).Asserts(rbac.ResourceOauth2AppCodeToken, policy.ActionRead).Returns([]database.OAuth2ProviderDeviceCode{deviceCode})
5568+
}))
5569+
s.Run("ConsumeOAuth2ProviderDeviceCodeByPrefix", s.Subtest(func(db database.Store, check *expects) {
5570+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5571+
user := dbgen.User(s.T(), db, database.User{})
5572+
// Use unique identifiers to avoid test isolation issues
5573+
// Device code prefix must be exactly 8 characters
5574+
uniquePrefix := fmt.Sprintf("t%07d", time.Now().UnixNano()%10000000)
5575+
uniqueUserCode := fmt.Sprintf("USER%04d", time.Now().UnixNano()%10000)
5576+
// Create device code using dbgen (now available!)
5577+
deviceCode := dbgen.OAuth2ProviderDeviceCode(s.T(), db, database.OAuth2ProviderDeviceCode{
5578+
DeviceCodePrefix: uniquePrefix,
5579+
UserCode: uniqueUserCode,
5580+
ClientID: app.ID,
5581+
ExpiresAt: time.Now().Add(24 * time.Hour), // Extended expiry for test stability
5582+
})
5583+
// Authorize the device code so it can be consumed
5584+
deviceCode, err := db.UpdateOAuth2ProviderDeviceCodeAuthorization(s.T().Context(), database.UpdateOAuth2ProviderDeviceCodeAuthorizationParams{
5585+
ID: deviceCode.ID,
5586+
UserID: uuid.NullUUID{UUID: user.ID, Valid: true},
5587+
Status: database.OAuth2DeviceStatusAuthorized,
5588+
})
5589+
require.NoError(s.T(), err)
5590+
require.Equal(s.T(), database.OAuth2DeviceStatusAuthorized, deviceCode.Status)
5591+
check.Args(uniquePrefix).Asserts(deviceCode, policy.ActionUpdate).Returns(deviceCode)
5592+
}))
5593+
s.Run("UpdateOAuth2ProviderDeviceCodeAuthorization", s.Subtest(func(db database.Store, check *expects) {
5594+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5595+
user := dbgen.User(s.T(), db, database.User{})
5596+
// Create device code using dbgen
5597+
deviceCode := dbgen.OAuth2ProviderDeviceCode(s.T(), db, database.OAuth2ProviderDeviceCode{
5598+
ClientID: app.ID,
5599+
})
5600+
require.Equal(s.T(), database.OAuth2DeviceStatusPending, deviceCode.Status)
5601+
check.Args(database.UpdateOAuth2ProviderDeviceCodeAuthorizationParams{
5602+
ID: deviceCode.ID,
5603+
UserID: uuid.NullUUID{UUID: user.ID, Valid: true},
5604+
Status: database.OAuth2DeviceStatusAuthorized,
5605+
}).Asserts(deviceCode, policy.ActionUpdate)
5606+
}))
5607+
s.Run("DeleteOAuth2ProviderDeviceCodeByID", s.Subtest(func(db database.Store, check *expects) {
5608+
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
5609+
deviceCode, err := db.InsertOAuth2ProviderDeviceCode(context.Background(), database.InsertOAuth2ProviderDeviceCodeParams{
5610+
ClientID: app.ID,
5611+
DeviceCodePrefix: "testpref",
5612+
UserCode: "TEST1234",
5613+
VerificationUri: "http://example.com/device",
5614+
})
5615+
require.NoError(s.T(), err)
5616+
check.Args(deviceCode.ID).Asserts(deviceCode, policy.ActionDelete)
5617+
}))
5618+
s.Run("DeleteExpiredOAuth2ProviderDeviceCodes", s.Subtest(func(db database.Store, check *expects) {
5619+
check.Args().Asserts(rbac.ResourceSystem, policy.ActionDelete)
5620+
}))
5621+
}
5622+
55015623
func (s *MethodTestSuite) TestResourcesMonitor() {
55025624
createAgent := func(t *testing.T, db database.Store) (database.WorkspaceAgent, database.WorkspaceTable) {
55035625
t.Helper()

coderd/database/dbgen/dbgen.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,7 +1199,7 @@ func OAuth2ProviderAppCode(t testing.TB, db database.Store, seed database.OAuth2
11991199
code, err := db.InsertOAuth2ProviderAppCode(genCtx, database.InsertOAuth2ProviderAppCodeParams{
12001200
ID: takeFirst(seed.ID, uuid.New()),
12011201
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
1202-
ExpiresAt: takeFirst(seed.CreatedAt, dbtime.Now()),
1202+
ExpiresAt: takeFirst(seed.ExpiresAt, dbtime.Now().Add(24*time.Hour)),
12031203
SecretPrefix: takeFirstSlice(seed.SecretPrefix, []byte("prefix")),
12041204
HashedSecret: takeFirstSlice(seed.HashedSecret, []byte("hashed-secret")),
12051205
AppID: takeFirst(seed.AppID, uuid.New()),
@@ -1216,7 +1216,7 @@ func OAuth2ProviderAppToken(t testing.TB, db database.Store, seed database.OAuth
12161216
token, err := db.InsertOAuth2ProviderAppToken(genCtx, database.InsertOAuth2ProviderAppTokenParams{
12171217
ID: takeFirst(seed.ID, uuid.New()),
12181218
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
1219-
ExpiresAt: takeFirst(seed.CreatedAt, dbtime.Now()),
1219+
ExpiresAt: takeFirst(seed.ExpiresAt, dbtime.Now().Add(24*time.Hour)),
12201220
HashPrefix: takeFirstSlice(seed.HashPrefix, []byte("prefix")),
12211221
RefreshHash: takeFirstSlice(seed.RefreshHash, []byte("hashed-secret")),
12221222
AppSecretID: takeFirst(seed.AppSecretID, uuid.New()),
@@ -1228,6 +1228,25 @@ func OAuth2ProviderAppToken(t testing.TB, db database.Store, seed database.OAuth
12281228
return token
12291229
}
12301230

1231+
func OAuth2ProviderDeviceCode(t testing.TB, db database.Store, seed database.OAuth2ProviderDeviceCode) database.OAuth2ProviderDeviceCode {
1232+
deviceCode, err := db.InsertOAuth2ProviderDeviceCode(genCtx, database.InsertOAuth2ProviderDeviceCodeParams{
1233+
ID: takeFirst(seed.ID, uuid.New()),
1234+
CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()),
1235+
ExpiresAt: takeFirst(seed.ExpiresAt, dbtime.Now().Add(24*time.Hour)),
1236+
DeviceCodeHash: takeFirstSlice(seed.DeviceCodeHash, []byte("device-hash")),
1237+
DeviceCodePrefix: takeFirst(seed.DeviceCodePrefix, testutil.GetRandomName(t)[:8]),
1238+
UserCode: takeFirst(seed.UserCode, testutil.GetRandomName(t)),
1239+
ClientID: takeFirst(seed.ClientID, uuid.New()),
1240+
VerificationUri: takeFirst(seed.VerificationUri, "https://example.com/device"),
1241+
VerificationUriComplete: seed.VerificationUriComplete,
1242+
Scope: seed.Scope,
1243+
ResourceUri: seed.ResourceUri,
1244+
PollingInterval: takeFirst(seed.PollingInterval, 5),
1245+
})
1246+
require.NoError(t, err, "insert oauth2 device code")
1247+
return deviceCode
1248+
}
1249+
12311250
func WorkspaceAgentMemoryResourceMonitor(t testing.TB, db database.Store, seed database.WorkspaceAgentMemoryResourceMonitor) database.WorkspaceAgentMemoryResourceMonitor {
12321251
monitor, err := db.InsertMemoryResourceMonitor(genCtx, database.InsertMemoryResourceMonitorParams{
12331252
AgentID: takeFirst(seed.AgentID, uuid.New()),

coderd/database/dbmetrics/querymetrics.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 30 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dump.sql

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)