Skip to content

feat: add RFC 9728 OAuth2 resource metadata support #18920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,7 @@ func New(options *Options) *API {
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
Logger: options.Logger,
AccessURL: options.AccessURL,
})
// Same as above but it redirects to the login page.
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
Expand All @@ -801,6 +802,7 @@ func New(options *Options) *API {
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
Logger: options.Logger,
AccessURL: options.AccessURL,
})
// Same as the first but it's optional.
apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
Expand All @@ -812,6 +814,7 @@ func New(options *Options) *API {
SessionTokenFunc: nil, // Default behavior
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
Logger: options.Logger,
AccessURL: options.AccessURL,
})

workspaceAgentInfo := httpmw.ExtractWorkspaceAgentAndLatestBuild(httpmw.ExtractWorkspaceAgentAndLatestBuildConfig{
Expand Down
91 changes: 58 additions & 33 deletions coderd/httpmw/apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ type ExtractAPIKeyConfig struct {
// a user is authenticated to prevent additional CLI invocations.
PostAuthAdditionalHeadersFunc func(a rbac.Subject, header http.Header)

// AccessURL is the configured access URL for this Coder deployment.
// Used for generating OAuth2 resource metadata URLs in WWW-Authenticate headers.
AccessURL *url.URL

// Logger is used for logging middleware operations.
Logger slog.Logger
}
Expand Down Expand Up @@ -214,29 +218,9 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
return nil, nil, false
}

// Add WWW-Authenticate header for 401/403 responses (RFC 6750)
// Add WWW-Authenticate header for 401/403 responses (RFC 6750 + RFC 9728)
if code == http.StatusUnauthorized || code == http.StatusForbidden {
var wwwAuth string

switch code {
case http.StatusUnauthorized:
// Map 401 to invalid_token with specific error descriptions
switch {
case strings.Contains(response.Message, "expired") || strings.Contains(response.Detail, "expired"):
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token has expired"`
case strings.Contains(response.Message, "audience") || strings.Contains(response.Message, "mismatch"):
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token audience does not match this resource"`
default:
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token is invalid"`
}
case http.StatusForbidden:
// Map 403 to insufficient_scope per RFC 6750
wwwAuth = `Bearer realm="coder", error="insufficient_scope", error_description="The request requires higher privileges than provided by the access token"`
default:
wwwAuth = `Bearer realm="coder"`
}

rw.Header().Set("WWW-Authenticate", wwwAuth)
rw.Header().Set("WWW-Authenticate", buildWWWAuthenticateHeader(cfg.AccessURL, r, code, response))
}

httpapi.Write(ctx, rw, code, response)
Expand Down Expand Up @@ -272,7 +256,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon

// Validate OAuth2 provider app token audience (RFC 8707) if applicable
if key.LoginType == database.LoginTypeOAuth2ProviderApp {
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, r); err != nil {
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, cfg.AccessURL, r); err != nil {
// Log the detailed error for debugging but don't expose it to the client
cfg.Logger.Debug(ctx, "oauth2 token audience validation failed", slog.Error(err))
return optionalWrite(http.StatusForbidden, codersdk.Response{
Expand Down Expand Up @@ -489,7 +473,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon

// validateOAuth2ProviderAppTokenAudience validates that an OAuth2 provider app token
// is being used with the correct audience/resource server (RFC 8707).
func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, r *http.Request) error {
func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, accessURL *url.URL, r *http.Request) error {
// Get the OAuth2 provider app token to check its audience
//nolint:gocritic // System needs to access token for audience validation
token, err := db.GetOAuth2ProviderAppTokenByAPIKeyID(dbauthz.AsSystemRestricted(ctx), key.ID)
Expand All @@ -502,8 +486,8 @@ func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Sto
return nil
}

// Extract the expected audience from the request
expectedAudience := extractExpectedAudience(r)
// Extract the expected audience from the access URL
expectedAudience := extractExpectedAudience(accessURL, r)

// Normalize both audience values for RFC 3986 compliant comparison
normalizedTokenAudience := normalizeAudienceURI(token.Audience.String)
Expand Down Expand Up @@ -624,18 +608,59 @@ func normalizePathSegments(path string) string {

// Test export functions for testing package access

// buildWWWAuthenticateHeader constructs RFC 6750 + RFC 9728 compliant WWW-Authenticate header
func buildWWWAuthenticateHeader(accessURL *url.URL, r *http.Request, code int, response codersdk.Response) string {
// Use the configured access URL for resource metadata
if accessURL == nil {
scheme := "https"
if r.TLS == nil {
scheme = "http"
}

// Use the Host header to construct the canonical audience URI
accessURL = &url.URL{
Scheme: scheme,
Host: r.Host,
}
}

resourceMetadata := accessURL.JoinPath("/.well-known/oauth-protected-resource").String()

switch code {
case http.StatusUnauthorized:
switch {
case strings.Contains(response.Message, "expired") || strings.Contains(response.Detail, "expired"):
return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token has expired", resource_metadata=%q`, resourceMetadata)
case strings.Contains(response.Message, "audience") || strings.Contains(response.Message, "mismatch"):
return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token audience does not match this resource", resource_metadata=%q`, resourceMetadata)
default:
return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token is invalid", resource_metadata=%q`, resourceMetadata)
}
case http.StatusForbidden:
return fmt.Sprintf(`Bearer realm="coder", error="insufficient_scope", error_description="The request requires higher privileges than provided by the access token", resource_metadata=%q`, resourceMetadata)
default:
return fmt.Sprintf(`Bearer realm="coder", resource_metadata=%q`, resourceMetadata)
}
}

// extractExpectedAudience determines the expected audience for the current request.
// This should match the resource parameter used during authorization.
func extractExpectedAudience(r *http.Request) string {
func extractExpectedAudience(accessURL *url.URL, r *http.Request) string {
// For MCP compliance, the audience should be the canonical URI of the resource server
// This typically matches the access URL of the Coder deployment
scheme := "https"
if r.TLS == nil {
scheme = "http"
}
var audience string

if accessURL != nil {
audience = accessURL.String()
} else {
scheme := "https"
if r.TLS == nil {
scheme = "http"
}

// Use the Host header to construct the canonical audience URI
audience := fmt.Sprintf("%s://%s", scheme, r.Host)
// Use the Host header to construct the canonical audience URI
audience = fmt.Sprintf("%s://%s", scheme, r.Host)
}

// Normalize the URI according to RFC 3986 for consistent comparison
return normalizeAudienceURI(audience)
Expand Down
51 changes: 49 additions & 2 deletions coderd/httpmw/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"net/http"
"net/url"
"regexp"
"strings"

"github.com/go-chi/cors"

Expand All @@ -28,20 +29,66 @@ const (
func Cors(allowAll bool, origins ...string) func(next http.Handler) http.Handler {
if len(origins) == 0 {
// The default behavior is '*', so putting the empty string defaults to
// the secure behavior of blocking CORs requests.
// the secure behavior of blocking CORS requests.
origins = []string{""}
}
if allowAll {
origins = []string{"*"}
}
return cors.Handler(cors.Options{

// Standard CORS for most endpoints
standardCors := cors.Handler(cors.Options{
AllowedOrigins: origins,
// We only need GET for latency requests
AllowedMethods: []string{http.MethodOptions, http.MethodGet},
AllowedHeaders: []string{"Accept", "Content-Type", "X-LATENCY-CHECK", "X-CSRF-TOKEN"},
// Do not send any cookies
AllowCredentials: false,
})

// Permissive CORS for OAuth2 and MCP endpoints
permissiveCors := cors.Handler(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{
http.MethodGet,
http.MethodPost,
http.MethodDelete,
http.MethodOptions,
},
AllowedHeaders: []string{
"Content-Type",
"Accept",
"Authorization",
"x-api-key",
"Mcp-Session-Id",
"MCP-Protocol-Version",
"Last-Event-ID",
},
ExposedHeaders: []string{
"Content-Type",
"Authorization",
"x-api-key",
"Mcp-Session-Id",
"MCP-Protocol-Version",
},
MaxAge: 86400, // 24 hours in seconds
AllowCredentials: false,
})

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Use permissive CORS for OAuth2, MCP, and well-known endpoints
if strings.HasPrefix(r.URL.Path, "/oauth2/") ||
strings.HasPrefix(r.URL.Path, "/api/experimental/mcp/") ||
strings.HasPrefix(r.URL.Path, "/.well-known/oauth-") {
permissiveCors(next).ServeHTTP(w, r)
return
}

// Use standard CORS for all other endpoints
standardCors(next).ServeHTTP(w, r)
})
}
}

func WorkspaceAppCors(regex *regexp.Regexp, app appurl.ApplicationURL) func(next http.Handler) http.Handler {
Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/csp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestCSP(t *testing.T) {

expected := []string{
"frame-src 'self' *.test.com *.coder.com *.coder2.com",
"media-src 'self' media.com media2.com",
"media-src 'self' " + strings.Join(expectedMedia, " "),
strings.Join([]string{
"connect-src", "'self'",
// Added from host header.
Expand Down
2 changes: 1 addition & 1 deletion coderd/httpmw/httpmw_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ func TestExtractExpectedAudience(t *testing.T) {
}
req.Host = tc.host

result := extractExpectedAudience(req)
result := extractExpectedAudience(nil, req)
assert.Equal(t, tc.expected, result)
})
}
Expand Down
6 changes: 4 additions & 2 deletions coderd/oauth2provider/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar
p := httpapi.NewQueryParamParser()
vals := r.URL.Query()

p.RequiredNotEmpty("state", "response_type", "client_id")
p.RequiredNotEmpty("response_type", "client_id")

params := authorizeParams{
clientID: p.String(vals, "", "client_id"),
Expand Down Expand Up @@ -154,7 +154,9 @@ func ProcessAuthorize(db database.Store, accessURL *url.URL) http.HandlerFunc {

newQuery := params.redirectURL.Query()
newQuery.Add("code", code.Formatted)
newQuery.Add("state", params.state)
if params.state != "" {
newQuery.Add("state", params.state)
}
params.redirectURL.RawQuery = newQuery.Encode()

http.Redirect(rw, r, params.redirectURL.String(), http.StatusTemporaryRedirect)
Expand Down
Loading