Skip to content

Commit bd01a1d

Browse files
committed
feat: enhance OAuth2 RFC compliance with resource metadata and CORS improvements
Change-Id: I99fc71255165133bf858268030d39d2b1a71a288 Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent f47efc6 commit bd01a1d

File tree

4 files changed

+114
-37
lines changed

4 files changed

+114
-37
lines changed

coderd/coderd.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,7 @@ func New(options *Options) *API {
790790
SessionTokenFunc: nil, // Default behavior
791791
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
792792
Logger: options.Logger,
793+
AccessURL: options.AccessURL,
793794
})
794795
// Same as above but it redirects to the login page.
795796
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
@@ -801,6 +802,7 @@ func New(options *Options) *API {
801802
SessionTokenFunc: nil, // Default behavior
802803
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
803804
Logger: options.Logger,
805+
AccessURL: options.AccessURL,
804806
})
805807
// Same as the first but it's optional.
806808
apiKeyMiddlewareOptional := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
@@ -812,6 +814,7 @@ func New(options *Options) *API {
812814
SessionTokenFunc: nil, // Default behavior
813815
PostAuthAdditionalHeadersFunc: options.PostAuthAdditionalHeadersFunc,
814816
Logger: options.Logger,
817+
AccessURL: options.AccessURL,
815818
})
816819

817820
workspaceAgentInfo := httpmw.ExtractWorkspaceAgentAndLatestBuild(httpmw.ExtractWorkspaceAgentAndLatestBuildConfig{

coderd/httpmw/apikey.go

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ type ExtractAPIKeyConfig struct {
113113
// a user is authenticated to prevent additional CLI invocations.
114114
PostAuthAdditionalHeadersFunc func(a rbac.Subject, header http.Header)
115115

116+
// AccessURL is the configured access URL for this Coder deployment.
117+
// Used for generating OAuth2 resource metadata URLs in WWW-Authenticate headers.
118+
AccessURL *url.URL
119+
116120
// Logger is used for logging middleware operations.
117121
Logger slog.Logger
118122
}
@@ -214,29 +218,9 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
214218
return nil, nil, false
215219
}
216220

217-
// Add WWW-Authenticate header for 401/403 responses (RFC 6750)
221+
// Add WWW-Authenticate header for 401/403 responses (RFC 6750 + RFC 9728)
218222
if code == http.StatusUnauthorized || code == http.StatusForbidden {
219-
var wwwAuth string
220-
221-
switch code {
222-
case http.StatusUnauthorized:
223-
// Map 401 to invalid_token with specific error descriptions
224-
switch {
225-
case strings.Contains(response.Message, "expired") || strings.Contains(response.Detail, "expired"):
226-
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token has expired"`
227-
case strings.Contains(response.Message, "audience") || strings.Contains(response.Message, "mismatch"):
228-
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token audience does not match this resource"`
229-
default:
230-
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token is invalid"`
231-
}
232-
case http.StatusForbidden:
233-
// Map 403 to insufficient_scope per RFC 6750
234-
wwwAuth = `Bearer realm="coder", error="insufficient_scope", error_description="The request requires higher privileges than provided by the access token"`
235-
default:
236-
wwwAuth = `Bearer realm="coder"`
237-
}
238-
239-
rw.Header().Set("WWW-Authenticate", wwwAuth)
223+
rw.Header().Set("WWW-Authenticate", buildWWWAuthenticateHeader(cfg.AccessURL, r, code, response))
240224
}
241225

242226
httpapi.Write(ctx, rw, code, response)
@@ -272,7 +256,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
272256

273257
// Validate OAuth2 provider app token audience (RFC 8707) if applicable
274258
if key.LoginType == database.LoginTypeOAuth2ProviderApp {
275-
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, r); err != nil {
259+
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, cfg.AccessURL, r); err != nil {
276260
// Log the detailed error for debugging but don't expose it to the client
277261
cfg.Logger.Debug(ctx, "oauth2 token audience validation failed", slog.Error(err))
278262
return optionalWrite(http.StatusForbidden, codersdk.Response{
@@ -489,7 +473,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
489473

490474
// validateOAuth2ProviderAppTokenAudience validates that an OAuth2 provider app token
491475
// is being used with the correct audience/resource server (RFC 8707).
492-
func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, r *http.Request) error {
476+
func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, accessURL *url.URL, r *http.Request) error {
493477
// Get the OAuth2 provider app token to check its audience
494478
//nolint:gocritic // System needs to access token for audience validation
495479
token, err := db.GetOAuth2ProviderAppTokenByAPIKeyID(dbauthz.AsSystemRestricted(ctx), key.ID)
@@ -502,8 +486,8 @@ func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Sto
502486
return nil
503487
}
504488

505-
// Extract the expected audience from the request
506-
expectedAudience := extractExpectedAudience(r)
489+
// Extract the expected audience from the access URL
490+
expectedAudience := extractExpectedAudience(accessURL, r)
507491

508492
// Normalize both audience values for RFC 3986 compliant comparison
509493
normalizedTokenAudience := normalizeAudienceURI(token.Audience.String)
@@ -624,18 +608,59 @@ func normalizePathSegments(path string) string {
624608

625609
// Test export functions for testing package access
626610

611+
// buildWWWAuthenticateHeader constructs RFC 6750 + RFC 9728 compliant WWW-Authenticate header
612+
func buildWWWAuthenticateHeader(accessURL *url.URL, r *http.Request, code int, response codersdk.Response) string {
613+
// Use the configured access URL for resource metadata
614+
if accessURL == nil {
615+
scheme := "https"
616+
if r.TLS == nil {
617+
scheme = "http"
618+
}
619+
620+
// Use the Host header to construct the canonical audience URI
621+
accessURL = &url.URL{
622+
Scheme: scheme,
623+
Host: r.Host,
624+
}
625+
}
626+
627+
resourceMetadata := accessURL.JoinPath("/.well-known/oauth-protected-resource").String()
628+
629+
switch code {
630+
case http.StatusUnauthorized:
631+
switch {
632+
case strings.Contains(response.Message, "expired") || strings.Contains(response.Detail, "expired"):
633+
return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token has expired", resource_metadata="%s"`, resourceMetadata)
634+
case strings.Contains(response.Message, "audience") || strings.Contains(response.Message, "mismatch"):
635+
return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token audience does not match this resource", resource_metadata="%s"`, resourceMetadata)
636+
default:
637+
return fmt.Sprintf(`Bearer realm="coder", error="invalid_token", error_description="The access token is invalid", resource_metadata="%s"`, resourceMetadata)
638+
}
639+
case http.StatusForbidden:
640+
return fmt.Sprintf(`Bearer realm="coder", error="insufficient_scope", error_description="The request requires higher privileges than provided by the access token", resource_metadata="%s"`, resourceMetadata)
641+
default:
642+
return fmt.Sprintf(`Bearer realm="coder", resource_metadata="%s"`, resourceMetadata)
643+
}
644+
}
645+
627646
// extractExpectedAudience determines the expected audience for the current request.
628647
// This should match the resource parameter used during authorization.
629-
func extractExpectedAudience(r *http.Request) string {
648+
func extractExpectedAudience(accessURL *url.URL, r *http.Request) string {
630649
// For MCP compliance, the audience should be the canonical URI of the resource server
631650
// This typically matches the access URL of the Coder deployment
632-
scheme := "https"
633-
if r.TLS == nil {
634-
scheme = "http"
635-
}
651+
var audience string
652+
653+
if accessURL != nil {
654+
audience = accessURL.String()
655+
} else {
656+
scheme := "https"
657+
if r.TLS == nil {
658+
scheme = "http"
659+
}
636660

637-
// Use the Host header to construct the canonical audience URI
638-
audience := fmt.Sprintf("%s://%s", scheme, r.Host)
661+
// Use the Host header to construct the canonical audience URI
662+
audience = fmt.Sprintf("%s://%s", scheme, r.Host)
663+
}
639664

640665
// Normalize the URI according to RFC 3986 for consistent comparison
641666
return normalizeAudienceURI(audience)

coderd/httpmw/cors.go

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"net/http"
55
"net/url"
66
"regexp"
7+
"strings"
78

89
"github.com/go-chi/cors"
910

@@ -28,20 +29,66 @@ const (
2829
func Cors(allowAll bool, origins ...string) func(next http.Handler) http.Handler {
2930
if len(origins) == 0 {
3031
// The default behavior is '*', so putting the empty string defaults to
31-
// the secure behavior of blocking CORs requests.
32+
// the secure behavior of blocking CORS requests.
3233
origins = []string{""}
3334
}
3435
if allowAll {
3536
origins = []string{"*"}
3637
}
37-
return cors.Handler(cors.Options{
38+
39+
// Standard CORS for most endpoints
40+
standardCors := cors.Handler(cors.Options{
3841
AllowedOrigins: origins,
3942
// We only need GET for latency requests
4043
AllowedMethods: []string{http.MethodOptions, http.MethodGet},
4144
AllowedHeaders: []string{"Accept", "Content-Type", "X-LATENCY-CHECK", "X-CSRF-TOKEN"},
4245
// Do not send any cookies
4346
AllowCredentials: false,
4447
})
48+
49+
// Permissive CORS for OAuth2 and MCP endpoints
50+
permissiveCors := cors.Handler(cors.Options{
51+
AllowedOrigins: []string{"*"},
52+
AllowedMethods: []string{
53+
http.MethodGet,
54+
http.MethodPost,
55+
http.MethodDelete,
56+
http.MethodOptions,
57+
},
58+
AllowedHeaders: []string{
59+
"Content-Type",
60+
"Accept",
61+
"Authorization",
62+
"x-api-key",
63+
"Mcp-Session-Id",
64+
"MCP-Protocol-Version",
65+
"Last-Event-ID",
66+
},
67+
ExposedHeaders: []string{
68+
"Content-Type",
69+
"Authorization",
70+
"x-api-key",
71+
"Mcp-Session-Id",
72+
"MCP-Protocol-Version",
73+
},
74+
MaxAge: 86400, // 24 hours in seconds
75+
AllowCredentials: false,
76+
})
77+
78+
return func(next http.Handler) http.Handler {
79+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
80+
// Use permissive CORS for OAuth2, MCP, and well-known endpoints
81+
if strings.HasPrefix(r.URL.Path, "/oauth2/") ||
82+
strings.HasPrefix(r.URL.Path, "/api/experimental/mcp/") ||
83+
strings.HasPrefix(r.URL.Path, "/.well-known/oauth-") {
84+
permissiveCors(next).ServeHTTP(w, r)
85+
return
86+
}
87+
88+
// Use standard CORS for all other endpoints
89+
standardCors(next).ServeHTTP(w, r)
90+
})
91+
}
4592
}
4693

4794
func WorkspaceAppCors(regex *regexp.Regexp, app appurl.ApplicationURL) func(next http.Handler) http.Handler {

coderd/oauth2provider/authorize.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func extractAuthorizeParams(r *http.Request, callbackURL *url.URL) (authorizePar
3333
p := httpapi.NewQueryParamParser()
3434
vals := r.URL.Query()
3535

36-
p.RequiredNotEmpty("state", "response_type", "client_id")
36+
p.RequiredNotEmpty("response_type", "client_id")
3737

3838
params := authorizeParams{
3939
clientID: p.String(vals, "", "client_id"),
@@ -154,7 +154,9 @@ func ProcessAuthorize(db database.Store, accessURL *url.URL) http.HandlerFunc {
154154

155155
newQuery := params.redirectURL.Query()
156156
newQuery.Add("code", code.Formatted)
157-
newQuery.Add("state", params.state)
157+
if params.state != "" {
158+
newQuery.Add("state", params.state)
159+
}
158160
params.redirectURL.RawQuery = newQuery.Encode()
159161

160162
http.Redirect(rw, r, params.redirectURL.String(), http.StatusTemporaryRedirect)

0 commit comments

Comments
 (0)