Skip to content

Commit dd9cb2f

Browse files
committed
chore: add OAuth2 device flow test scripts
Change-Id: Ic232851727e683ab3d8b7ce970c505588da2f827 Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent bdc12a1 commit dd9cb2f

File tree

4 files changed

+424
-8
lines changed

4 files changed

+424
-8
lines changed

scripts/oauth2/README.md

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,39 @@ export STATE="your-state"
102102
go run ./scripts/oauth2/oauth2-test-server.go
103103
```
104104

105+
### `test-device-flow.sh`
106+
107+
Tests the OAuth2 Device Authorization Flow (RFC 8628) using the golang.org/x/oauth2 library. This flow is designed for devices that either lack a web browser or have limited input capabilities.
108+
109+
Usage:
110+
111+
```bash
112+
# First set up an app
113+
eval $(./scripts/oauth2/setup-test-app.sh)
114+
115+
# Run the device flow test
116+
./scripts/oauth2/test-device-flow.sh
117+
```
118+
119+
Features:
120+
121+
- Implements the complete device authorization flow
122+
- Uses the `/x/oauth2` library for OAuth2 operations
123+
- Displays user code and verification URL
124+
- Automatically polls for token completion
125+
- Tests the access token with an API call
126+
- Colored output for better readability
127+
128+
### `oauth2-device-flow.go`
129+
130+
A Go program that implements the OAuth2 device authorization flow. Used internally by `test-device-flow.sh` but can also be run standalone:
131+
132+
```bash
133+
export CLIENT_ID="your-client-id"
134+
export CLIENT_SECRET="your-client-secret"
135+
go run ./scripts/oauth2/oauth2-device-flow.go
136+
```
137+
105138
## Example Workflow
106139

107140
1. **Run automated tests:**
@@ -126,7 +159,23 @@ go run ./scripts/oauth2/oauth2-test-server.go
126159
./scripts/oauth2/cleanup-test-app.sh
127160
```
128161

129-
3. **Generate PKCE for custom testing:**
162+
3. **Device authorization flow testing:**
163+
164+
```bash
165+
# Create app
166+
eval $(./scripts/oauth2/setup-test-app.sh)
167+
168+
# Run the device flow test
169+
./scripts/oauth2/test-device-flow.sh
170+
# - Shows device code and verification URL
171+
# - Polls for authorization completion
172+
# - Tests access token
173+
174+
# Clean up when done
175+
./scripts/oauth2/cleanup-test-app.sh
176+
```
177+
178+
4. **Generate PKCE for custom testing:**
130179

131180
```bash
132181
./scripts/oauth2/generate-pkce.sh
@@ -147,4 +196,5 @@ All scripts respect these environment variables:
147196
- Metadata: `GET /.well-known/oauth-authorization-server`
148197
- Authorization: `GET/POST /oauth2/authorize`
149198
- Token: `POST /oauth2/token`
199+
- Device Authorization: `POST /oauth2/device`
150200
- Apps API: `/api/v2/oauth2-provider/apps`

scripts/oauth2/device/server.go

Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"log"
8+
"net/http"
9+
"net/url"
10+
"os"
11+
"strings"
12+
"time"
13+
14+
"golang.org/x/oauth2"
15+
"golang.org/x/xerrors"
16+
)
17+
18+
const (
19+
// ANSI color codes
20+
colorReset = "\033[0m"
21+
colorRed = "\033[31m"
22+
colorGreen = "\033[32m"
23+
colorYellow = "\033[33m"
24+
colorBlue = "\033[34m"
25+
colorPurple = "\033[35m"
26+
colorCyan = "\033[36m"
27+
colorWhite = "\033[37m"
28+
)
29+
30+
type DeviceCodeResponse struct {
31+
DeviceCode string `json:"device_code"`
32+
UserCode string `json:"user_code"`
33+
VerificationURI string `json:"verification_uri"`
34+
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
35+
ExpiresIn int `json:"expires_in"`
36+
Interval int `json:"interval"`
37+
}
38+
39+
type TokenResponse struct {
40+
AccessToken string `json:"access_token"`
41+
TokenType string `json:"token_type"`
42+
ExpiresIn int `json:"expires_in"`
43+
RefreshToken string `json:"refresh_token,omitempty"`
44+
Scope string `json:"scope,omitempty"`
45+
}
46+
47+
type ErrorResponse struct {
48+
Error string `json:"error"`
49+
ErrorDescription string `json:"error_description,omitempty"`
50+
}
51+
52+
type Config struct {
53+
ClientID string
54+
ClientSecret string
55+
BaseURL string
56+
}
57+
58+
func main() {
59+
config := &Config{
60+
ClientID: os.Getenv("CLIENT_ID"),
61+
ClientSecret: os.Getenv("CLIENT_SECRET"),
62+
BaseURL: getEnvOrDefault("BASE_URL", "http://localhost:3000"),
63+
}
64+
65+
if config.ClientID == "" || config.ClientSecret == "" {
66+
log.Fatal("CLIENT_ID and CLIENT_SECRET must be set. Run: eval $(./setup-test-app.sh) first")
67+
}
68+
69+
ctx := context.Background()
70+
71+
// Step 1: Request device code
72+
_, _ = fmt.Printf("%s=== Step 1: Device Code Request ===%s\n", colorBlue, colorReset)
73+
deviceResp, err := requestDeviceCode(ctx, config)
74+
if err != nil {
75+
log.Fatalf("Failed to get device code: %v", err)
76+
}
77+
78+
_, _ = fmt.Printf("%sDevice Code Response:%s\n", colorGreen, colorReset)
79+
prettyJSON, _ := json.MarshalIndent(deviceResp, "", " ")
80+
_, _ = fmt.Printf("%s\n", prettyJSON)
81+
_, _ = fmt.Println()
82+
83+
// Step 2: Display user instructions
84+
_, _ = fmt.Printf("%s=== Step 2: User Authorization ===%s\n", colorYellow, colorReset)
85+
_, _ = fmt.Printf("Please visit: %s%s%s\n", colorCyan, deviceResp.VerificationURI, colorReset)
86+
_, _ = fmt.Printf("Enter code: %s%s%s\n", colorPurple, deviceResp.UserCode, colorReset)
87+
_, _ = fmt.Println()
88+
89+
if deviceResp.VerificationURIComplete != "" {
90+
_, _ = fmt.Printf("Or visit the complete URL: %s%s%s\n", colorCyan, deviceResp.VerificationURIComplete, colorReset)
91+
_, _ = fmt.Println()
92+
}
93+
94+
_, _ = fmt.Printf("Waiting for authorization (expires in %d seconds)...\n", deviceResp.ExpiresIn)
95+
_, _ = fmt.Printf("Polling every %d seconds...\n", deviceResp.Interval)
96+
_, _ = fmt.Println()
97+
98+
// Step 3: Poll for token
99+
_, _ = fmt.Printf("%s=== Step 3: Token Polling ===%s\n", colorBlue, colorReset)
100+
tokenResp, err := pollForToken(ctx, config, deviceResp)
101+
if err != nil {
102+
log.Fatalf("Failed to get access token: %v", err)
103+
}
104+
105+
_, _ = fmt.Printf("%s=== Authorization Successful! ===%s\n", colorGreen, colorReset)
106+
_, _ = fmt.Printf("%sAccess Token Response:%s\n", colorGreen, colorReset)
107+
prettyTokenJSON, _ := json.MarshalIndent(tokenResp, "", " ")
108+
_, _ = fmt.Printf("%s\n", prettyTokenJSON)
109+
_, _ = fmt.Println()
110+
111+
// Step 4: Test the access token
112+
_, _ = fmt.Printf("%s=== Step 4: Testing Access Token ===%s\n", colorBlue, colorReset)
113+
if err := testAccessToken(ctx, config, tokenResp.AccessToken); err != nil {
114+
log.Printf("%sWarning: Failed to test access token: %v%s", colorYellow, err, colorReset)
115+
} else {
116+
_, _ = fmt.Printf("%sAccess token is valid and working!%s\n", colorGreen, colorReset)
117+
}
118+
119+
_, _ = fmt.Println()
120+
_, _ = fmt.Printf("%sDevice authorization flow completed successfully!%s\n", colorGreen, colorReset)
121+
_, _ = fmt.Printf("You can now use the access token to make authenticated API requests.\n")
122+
}
123+
124+
func requestDeviceCode(ctx context.Context, config *Config) (*DeviceCodeResponse, error) {
125+
// Use x/oauth2 clientcredentials config to structure the request
126+
// clientConfig := &clientcredentials.Config{
127+
// ClientID: config.ClientID,
128+
// ClientSecret: config.ClientSecret,
129+
// TokenURL: config.BaseURL + "/oauth2/device", // Device code endpoint (RFC 8628)
130+
// }
131+
132+
// Create form data for device code request
133+
data := url.Values{}
134+
data.Set("client_id", config.ClientID)
135+
136+
// Optional: Add scope parameter
137+
// data.Set("scope", "openid profile")
138+
139+
// Make the request to the device authorization endpoint
140+
req, err := http.NewRequestWithContext(ctx, "POST", config.BaseURL+"/oauth2/device", strings.NewReader(data.Encode()))
141+
if err != nil {
142+
return nil, xerrors.Errorf("creating request: %w", err)
143+
}
144+
145+
// Set up basic auth with client credentials
146+
req.SetBasicAuth(config.ClientID, config.ClientSecret)
147+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
148+
149+
client := &http.Client{Timeout: 30 * time.Second}
150+
resp, err := client.Do(req)
151+
if err != nil {
152+
return nil, xerrors.Errorf("making request: %w", err)
153+
}
154+
defer func() { _ = resp.Body.Close() }()
155+
156+
if resp.StatusCode != http.StatusOK {
157+
var errResp ErrorResponse
158+
if err := json.NewDecoder(resp.Body).Decode(&errResp); err == nil {
159+
return nil, xerrors.Errorf("device code request failed: %s - %s", errResp.Error, errResp.ErrorDescription)
160+
}
161+
return nil, xerrors.Errorf("device code request failed with status %d", resp.StatusCode)
162+
}
163+
164+
var deviceResp DeviceCodeResponse
165+
if err := json.NewDecoder(resp.Body).Decode(&deviceResp); err != nil {
166+
return nil, xerrors.Errorf("decoding response: %w", err)
167+
}
168+
169+
return &deviceResp, nil
170+
}
171+
172+
func pollForToken(ctx context.Context, config *Config, deviceResp *DeviceCodeResponse) (*TokenResponse, error) {
173+
// Use x/oauth2 config for token exchange
174+
oauth2Config := &oauth2.Config{
175+
ClientID: config.ClientID,
176+
ClientSecret: config.ClientSecret,
177+
Endpoint: oauth2.Endpoint{
178+
TokenURL: config.BaseURL + "/oauth2/token",
179+
},
180+
}
181+
182+
interval := time.Duration(deviceResp.Interval) * time.Second
183+
if interval < 5*time.Second {
184+
interval = 5 * time.Second // Minimum polling interval
185+
}
186+
187+
deadline := time.Now().Add(time.Duration(deviceResp.ExpiresIn) * time.Second)
188+
ticker := time.NewTicker(interval)
189+
defer ticker.Stop()
190+
191+
for {
192+
select {
193+
case <-ctx.Done():
194+
return nil, ctx.Err()
195+
case <-ticker.C:
196+
if time.Now().After(deadline) {
197+
return nil, xerrors.New("device code expired")
198+
}
199+
200+
_, _ = fmt.Printf("Polling for token...\n")
201+
202+
// Create token exchange request using device_code grant
203+
data := url.Values{}
204+
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
205+
data.Set("device_code", deviceResp.DeviceCode)
206+
data.Set("client_id", config.ClientID)
207+
208+
req, err := http.NewRequestWithContext(ctx, "POST", oauth2Config.Endpoint.TokenURL, strings.NewReader(data.Encode()))
209+
if err != nil {
210+
return nil, xerrors.Errorf("creating token request: %w", err)
211+
}
212+
213+
req.SetBasicAuth(config.ClientID, config.ClientSecret)
214+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
215+
216+
client := &http.Client{Timeout: 30 * time.Second}
217+
resp, err := client.Do(req)
218+
if err != nil {
219+
_, _ = fmt.Printf("Request error: %v\n", err)
220+
continue
221+
}
222+
223+
var result map[string]interface{}
224+
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
225+
_ = resp.Body.Close()
226+
_, _ = fmt.Printf("Decode error: %v\n", err)
227+
continue
228+
}
229+
_ = resp.Body.Close()
230+
231+
if errorCode, ok := result["error"].(string); ok {
232+
switch errorCode {
233+
case "authorization_pending":
234+
_, _ = fmt.Printf("Authorization pending... continuing to poll\n")
235+
continue
236+
case "slow_down":
237+
_, _ = fmt.Printf("Slow down request - increasing polling interval by 5 seconds\n")
238+
interval += 5 * time.Second
239+
ticker.Reset(interval)
240+
continue
241+
case "access_denied":
242+
return nil, xerrors.New("access denied by user")
243+
case "expired_token":
244+
return nil, xerrors.New("device code expired")
245+
default:
246+
desc := ""
247+
if errorDesc, ok := result["error_description"].(string); ok {
248+
desc = " - " + errorDesc
249+
}
250+
return nil, xerrors.Errorf("token error: %s%s", errorCode, desc)
251+
}
252+
}
253+
254+
// Success case - convert to TokenResponse
255+
var tokenResp TokenResponse
256+
if accessToken, ok := result["access_token"].(string); ok {
257+
tokenResp.AccessToken = accessToken
258+
}
259+
if tokenType, ok := result["token_type"].(string); ok {
260+
tokenResp.TokenType = tokenType
261+
}
262+
if expiresIn, ok := result["expires_in"].(float64); ok {
263+
tokenResp.ExpiresIn = int(expiresIn)
264+
}
265+
if refreshToken, ok := result["refresh_token"].(string); ok {
266+
tokenResp.RefreshToken = refreshToken
267+
}
268+
if scope, ok := result["scope"].(string); ok {
269+
tokenResp.Scope = scope
270+
}
271+
272+
if tokenResp.AccessToken == "" {
273+
return nil, xerrors.New("no access token in response")
274+
}
275+
276+
return &tokenResp, nil
277+
}
278+
}
279+
}
280+
281+
func testAccessToken(ctx context.Context, config *Config, accessToken string) error {
282+
req, err := http.NewRequestWithContext(ctx, "GET", config.BaseURL+"/api/v2/users/me", nil)
283+
if err != nil {
284+
return xerrors.Errorf("creating request: %w", err)
285+
}
286+
287+
req.Header.Set("Coder-Session-Token", accessToken)
288+
289+
client := &http.Client{Timeout: 10 * time.Second}
290+
resp, err := client.Do(req)
291+
if err != nil {
292+
return xerrors.Errorf("making request: %w", err)
293+
}
294+
defer func() { _ = resp.Body.Close() }()
295+
296+
if resp.StatusCode != http.StatusOK {
297+
return xerrors.Errorf("API request failed with status %d", resp.StatusCode)
298+
}
299+
300+
var userInfo map[string]interface{}
301+
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
302+
return xerrors.Errorf("decoding response: %w", err)
303+
}
304+
305+
_, _ = fmt.Printf("%sAPI Test Response:%s\n", colorGreen, colorReset)
306+
prettyJSON, _ := json.MarshalIndent(userInfo, "", " ")
307+
_, _ = fmt.Printf("%s\n", prettyJSON)
308+
309+
return nil
310+
}
311+
312+
func getEnvOrDefault(key, defaultValue string) string {
313+
if value := os.Getenv(key); value != "" {
314+
return value
315+
}
316+
return defaultValue
317+
}

0 commit comments

Comments
 (0)