Skip to content

Commit 071383b

Browse files
authored
feat: add RFC 9728 OAuth2 resource metadata support (#18920)
# Enhanced OAuth2 and MCP Compliance for API Authentication This PR improves OAuth2 and MCP (Microsoft Cloud for Sovereignty) compliance by: 1. Adding RFC 9728 compliant `WWW-Authenticate` headers with resource metadata URLs 2. Passing the configured `AccessURL` to API key middleware for proper audience validation 3. Creating specialized CORS handling for OAuth2 and MCP endpoints with appropriate headers 4. Making the `state` parameter optional in OAuth2 authorization requests These changes ensure proper OAuth2 token audience validation against the configured access URL and improve interoperability with OAuth2 clients by providing better error responses and metadata discovery. Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent f47efc6 commit 071383b

File tree

6 files changed

+116
-39
lines changed

6 files changed

+116
-39
lines changed

coderd/coderd.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,7 @@ func New(options *Options) *API {
790790
SessionTokenFunc: nil, // Default behavior
791791
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
792792
Logger: options.Logger,
793+
AccessURL: options.AccessURL,
793794
})
794795
// Same as above but it redirects to the login page.
795796
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
@@ -801,6 +802,7 @@ func New(options *Options) *API {
801802
SessionTokenFunc: nil, // Default behavior
802803
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
803804
Logger: options.Logger,
805+
AccessURL: options.AccessURL,
804806
})
805807
// Same as the first but it's optional.
806808
apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
@@ -812,6 +814,7 @@ func New(options *Options) *API {
812814
SessionTokenFunc: nil, // Default behavior
813815
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
814816
Logger: options.Logger,
817+
AccessURL: options.AccessURL,
815818
})
816819

817820
workspaceAgentInfo := httpmw.ExtractWorkspaceAgentAndLatestBuild(httpmw.ExtractWorkspaceAgentAndLatestBuildConfig{

coderd/httpmw/apikey.go

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ type ExtractAPIKeyConfig struct {
113113
// a user is authenticated to prevent additional CLI invocations.
114114
PostAuthAdditionalHeadersFunc func(a rbac.Subject, header http.Header)
115115

116+
// AccessURL is the configured access URL for this Coder deployment.
117+
// Used for generating OAuth2 resource metadata URLs in WWW-Authenticate headers.
118+
AccessURL *url.URL
119+
116120
// Logger is used for logging middleware operations.
117121
Logger slog.Logger
118122
}
@@ -214,29 +218,9 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
214218
return nil, nil, false
215219
}
216220

217-
// Add WWW-Authenticate header for 401/403 responses (RFC 6750)
221+
// Add WWW-Authenticate header for 401/403 responses (RFC 6750 + RFC 9728)
218222
if code == http.StatusUnauthorized || code == http.StatusForbidden {
219-
var wwwAuth string
220-
221-
switch code {
222-
case http.StatusUnauthorized:
223-
// Map 401 to invalid_token with specific error descriptions
224-
switch {
225-
case strings.Contains(response.Message, "expired") || strings.Contains(response.Detail, "expired"):
226-
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token has expired"`
227-
case strings.Contains(response.Message, "audience") || strings.Contains(response.Message, "mismatch"):
228-
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token audience does not match this resource"`
229-
default:
230-
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token is invalid"`
231-
}
232-
case http.StatusForbidden:
233-
// Map 403 to insufficient_scope per RFC 6750
234-
wwwAuth = `Bearer realm="coder", error="insufficient_scope", error_description="The request requires higher privileges than provided by the access token"`
235-
default:
236-
wwwAuth = `Bearer realm="coder"`
237-
}
238-
239-
rw.Header().Set("WWW-Authenticate", wwwAuth)
223+
rw.Header().Set("WWW-Authenticate", buildWWWAuthenticateHeader(cfg.AccessURL, r, code, response))
240224
}
241225

242226
httpapi.Write(ctx, rw, code, response)
@@ -272,7 +256,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
272256

273257
// Validate OAuth2 provider app token audience (RFC 8707) if applicable
274258
if key.LoginType == database.LoginTypeOAuth2ProviderApp {
275-
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, r); err != nil {
259+
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, cfg.AccessURL, r); err != nil {
276260
// Log the detailed error for debugging but don't expose it to the client
277261
cfg.Logger.Debug(ctx, "oauth2 token audience validation failed", slog.Error(err))
278262
return optionalWrite(http.StatusForbidden, codersdk.Response{
@@ -489,7 +473,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
489473

490474
// validateOAuth2ProviderAppTokenAudience validates that an OAuth2 provider app token
491475
// is being used with the correct audience/resource server (RFC 8707).
492-
func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, r *http.Request) error {
476+
func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, accessURL *url.URL, r *http.Request) error {
493477
// Get the OAuth2 provider app token to check its audience
494478
//nolint:gocritic // System needs to access token for audience validation
495479
token, err := db.GetOAuth2ProviderAppTokenByAPIKeyID(dbauthz.AsSystemRestricted(ctx), key.ID)
@@ -502,8 +486,8 @@ func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Sto
502486
return nil
503487
}
504488

505-
// Extract the expected audience from the request
506-
expectedAudience := extractExpectedAudience(r)
489+
// Extract the expected audience from the access URL
490+
expectedAudience := extractExpectedAudience(accessURL, r)
507491

508492
// Normalize both audience values for RFC 3986 compliant comparison
509493
normalizedTokenAudience := normalizeAudienceURI(token.Audience.String)
@@ -624,18 +608,59 @@ func normalizePathSegments(path string) string {
624608

625609
// Test export functions for testing package access
626610

611+
// buildWWWAuthenticateHeader constructs RFC 6750 + RFC 9728 compliant WWW-Authenticate header
612+
func buildWWWAuthenticateHeader(accessURL *url.URL, r *http.Request, code int, response codersdk.Response) string {
613+
// Use the configured access URL for resource metadata
614+
if accessURL == nil {
615+
scheme := "https"
616+
if r.TLS == nil {
617+
scheme = "http"
618+
}
619+
620+
// Use the Host header to construct the canonical audience URI
621+
accessURL = &url.URL{
622+
Scheme: scheme,
623+
Host: r.Host,
624+
}
625+
}
626+
627+
resourceMetadata := accessURL.JoinPath("/.well-known/oauth-protected-resource").String()
628+
629+
switch code {
630+
case http.StatusUnauthorized:
631+
switch {
632+
case strings.Contains(response.Message, "expired") || strings.Contains(response.Detail, "expired"):
633+
return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token has expired", resource_metadata=%q`, resourceMetadata)
634+
case strings.Contains(response.Message, "audience") || strings.Contains(response.Message, "mismatch"):
635+
return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token audience does not match this resource", resource_metadata=%q`, resourceMetadata)
636+
default:
637+
return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token is invalid", resource_metadata=%q`, resourceMetadata)
638+
}
639+
case http.StatusForbidden:
640+
return fmt.Sprintf(`Bearer realm="coder", error="insufficient_scope", error_description="The request requires higher privileges than provided by the access token", resource_metadata=%q`, resourceMetadata)
641+
default:
642+
return fmt.Sprintf(`Bearer realm="coder", resource_metadata=%q`, resourceMetadata)
643+
}
644+
}
645+
627646
// extractExpectedAudience determines the expected audience for the current request.
628647
// This should match the resource parameter used during authorization.
629-
func extractExpectedAudience(r *http.Request) string {
648+
func extractExpectedAudience(accessURL *url.URL, r *http.Request) string {
630649
// For MCP compliance, the audience should be the canonical URI of the resource server
631650
// This typically matches the access URL of the Coder deployment
632-
scheme := "https"
633-
if r.TLS == nil {
634-
scheme = "http"
635-
}
651+
var audience string
652+
653+
if accessURL != nil {
654+
audience = accessURL.String()
655+
} else {
656+
scheme := "https"
657+
if r.TLS == nil {
658+
scheme = "http"
659+
}
636660

637-
// Use the Host header to construct the canonical audience URI
638-
audience := fmt.Sprintf("%s://%s", scheme, r.Host)
661+
// Use the Host header to construct the canonical audience URI
662+
audience = fmt.Sprintf("%s://%s", scheme, r.Host)
663+
}
639664

640665
// Normalize the URI according to RFC 3986 for consistent comparison
641666
return normalizeAudienceURI(audience)

coderd/httpmw/cors.go

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"net/http"
55
"net/url"
66
"regexp"
7+
"strings"
78

89
"github.com/go-chi/cors"
910

@@ -28,20 +29,66 @@ const (
2829
func Cors(allowAll bool, origins ...string) func(next http.Handler) http.Handler {
2930
if len(origins) == 0 {
3031
// The default behavior is '*', so putting the empty string defaults to
31-
// the secure behavior of blocking CORs requests.
32+
// the secure behavior of blocking CORS requests.
3233
origins = []string{""}
3334
}
3435
if allowAll {
3536
origins = []string{"*"}
3637
}
37-
return cors.Handler(cors.Options{
38+
39+
// Standard CORS for most endpoints
40+
standardCors := cors.Handler(cors.Options{
3841
AllowedOrigins: origins,
3942
// We only need GET for latency requests
4043
AllowedMethods: []string{http.MethodOptions, http.MethodGet},
4144
AllowedHeaders: []string{"Accept", "Content-Type", "X-LATENCY-CHECK", "X-CSRF-TOKEN"},
4245
// Do not send any cookies
4346
AllowCredentials: false,
4447
})
48+
49+
// Permissive CORS for OAuth2 and MCP endpoints
50+
permissiveCors := cors.Handler(cors.Options{
51+
AllowedOrigins: []string{"*"},
52+
AllowedMethods: []string{
53+
http.MethodGet,
54+
http.MethodPost,
55+
http.MethodDelete,
56+
http.MethodOptions,
57+
},
58+
AllowedHeaders: []string{
59+
"Content-Type",
60+
"Accept",
61+
"Authorization",
62+
"x-api-key",
63+
"Mcp-Session-Id",
64+
"MCP-Protocol-Version",
65+
"Last-Event-ID",
66+
},
67+
ExposedHeaders: []string{
68+
"Content-Type",
69+
"Authorization",
70+
"x-api-key",
71+
"Mcp-Session-Id",
72+
"MCP-Protocol-Version",
73+
},
74+
MaxAge: 86400, // 24 hours in seconds
75+
AllowCredentials: false,
76+
})
77+
78+
return func(next http.Handler) http.Handler {
79+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
80+
// Use permissive CORS for OAuth2, MCP, and well-known endpoints
81+
if strings.HasPrefix(r.URL.Path, "/oauth2/") ||
82+
strings.HasPrefix(r.URL.Path, "/api/experimental/mcp/") ||
83+
strings.HasPrefix(r.URL.Path, "/.well-known/oauth-") {
84+
permissiveCors(next).ServeHTTP(w, r)
85+
return
86+
}
87+
88+
// Use standard CORS for all other endpoints
89+
standardCors(next).ServeHTTP(w, r)
90+
})
91+
}
4592
}
4693

4794
func WorkspaceAppCors(regex *regexp.Regexp, app appurl.ApplicationURL) func(next http.Handler) http.Handler {

coderd/httpmw/csp_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func TestCSP(t *testing.T) {
3434

3535
expected := []string{
3636
"frame-src 'self' *.test.com *.coder.com *.coder2.com",
37-
"media-src 'self' media.com media2.com",
37+
"media-src 'self' " + strings.Join(expectedMedia, " "),
3838
strings.Join([]string{
3939
"connect-src", "'self'",
4040
// Added from host header.

coderd/httpmw/httpmw_internal_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ func TestExtractExpectedAudience(t *testing.T) {
258258
}
259259
req.Host = tc.host
260260

261-
result := extractExpectedAudience(req)
261+
result := extractExpectedAudience(nil, req)
262262
assert.Equal(t, tc.expected, result)
263263
})
264264
}

coderd/oauth2provider/authorize.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar
3333
p := httpapi.NewQueryParamParser()
3434
vals := r.URL.Query()
3535

36-
p.RequiredNotEmpty("state", "response_type", "client_id")
36+
p.RequiredNotEmpty("response_type", "client_id")
3737

3838
params := authorizeParams{
3939
clientID: p.String(vals, "", "client_id"),
@@ -154,7 +154,9 @@ func ProcessAuthorize(db database.Store, accessURL *url.URL) http.HandlerFunc {
154154

155155
newQuery := params.redirectURL.Query()
156156
newQuery.Add("code", code.Formatted)
157-
newQuery.Add("state", params.state)
157+
if params.state != "" {
158+
newQuery.Add("state", params.state)
159+
}
158160
params.redirectURL.RawQuery = newQuery.Encode()
159161

160162
http.Redirect(rw, r, params.redirectURL.String(), http.StatusTemporaryRedirect)

0 commit comments

Comments
 (0)