|
| 1 | +package main |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "encoding/json" |
| 6 | + "fmt" |
| 7 | + "log" |
| 8 | + "net/http" |
| 9 | + "net/url" |
| 10 | + "os" |
| 11 | + "strings" |
| 12 | + "time" |
| 13 | + |
| 14 | + "golang.org/x/oauth2" |
| 15 | + "golang.org/x/xerrors" |
| 16 | +) |
| 17 | + |
| 18 | +const ( |
| 19 | + // ANSI color codes |
| 20 | + colorReset = "\033[0m" |
| 21 | + colorRed = "\033[31m" |
| 22 | + colorGreen = "\033[32m" |
| 23 | + colorYellow = "\033[33m" |
| 24 | + colorBlue = "\033[34m" |
| 25 | + colorPurple = "\033[35m" |
| 26 | + colorCyan = "\033[36m" |
| 27 | + colorWhite = "\033[37m" |
| 28 | +) |
| 29 | + |
| 30 | +type DeviceCodeResponse struct { |
| 31 | + DeviceCode string `json:"device_code"` |
| 32 | + UserCode string `json:"user_code"` |
| 33 | + VerificationURI string `json:"verification_uri"` |
| 34 | + VerificationURIComplete string `json:"verification_uri_complete,omitempty"` |
| 35 | + ExpiresIn int `json:"expires_in"` |
| 36 | + Interval int `json:"interval"` |
| 37 | +} |
| 38 | + |
| 39 | +type TokenResponse struct { |
| 40 | + AccessToken string `json:"access_token"` |
| 41 | + TokenType string `json:"token_type"` |
| 42 | + ExpiresIn int `json:"expires_in"` |
| 43 | + RefreshToken string `json:"refresh_token,omitempty"` |
| 44 | + Scope string `json:"scope,omitempty"` |
| 45 | +} |
| 46 | + |
| 47 | +type ErrorResponse struct { |
| 48 | + Error string `json:"error"` |
| 49 | + ErrorDescription string `json:"error_description,omitempty"` |
| 50 | +} |
| 51 | + |
| 52 | +type Config struct { |
| 53 | + ClientID string |
| 54 | + ClientSecret string |
| 55 | + BaseURL string |
| 56 | +} |
| 57 | + |
| 58 | +func main() { |
| 59 | + config := &Config{ |
| 60 | + ClientID: os.Getenv("CLIENT_ID"), |
| 61 | + ClientSecret: os.Getenv("CLIENT_SECRET"), |
| 62 | + BaseURL: getEnvOrDefault("BASE_URL", "http://localhost:3000"), |
| 63 | + } |
| 64 | + |
| 65 | + if config.ClientID == "" || config.ClientSecret == "" { |
| 66 | + log.Fatal("CLIENT_ID and CLIENT_SECRET must be set. Run: eval $(./setup-test-app.sh) first") |
| 67 | + } |
| 68 | + |
| 69 | + ctx := context.Background() |
| 70 | + |
| 71 | + // Step 1: Request device code |
| 72 | + _, _ = fmt.Printf("%s=== Step 1: Device Code Request ===%s\n", colorBlue, colorReset) |
| 73 | + deviceResp, err := requestDeviceCode(ctx, config) |
| 74 | + if err != nil { |
| 75 | + log.Fatalf("Failed to get device code: %v", err) |
| 76 | + } |
| 77 | + |
| 78 | + _, _ = fmt.Printf("%sDevice Code Response:%s\n", colorGreen, colorReset) |
| 79 | + prettyJSON, _ := json.MarshalIndent(deviceResp, "", " ") |
| 80 | + _, _ = fmt.Printf("%s\n", prettyJSON) |
| 81 | + _, _ = fmt.Println() |
| 82 | + |
| 83 | + // Step 2: Display user instructions |
| 84 | + _, _ = fmt.Printf("%s=== Step 2: User Authorization ===%s\n", colorYellow, colorReset) |
| 85 | + _, _ = fmt.Printf("Please visit: %s%s%s\n", colorCyan, deviceResp.VerificationURI, colorReset) |
| 86 | + _, _ = fmt.Printf("Enter code: %s%s%s\n", colorPurple, deviceResp.UserCode, colorReset) |
| 87 | + _, _ = fmt.Println() |
| 88 | + |
| 89 | + if deviceResp.VerificationURIComplete != "" { |
| 90 | + _, _ = fmt.Printf("Or visit the complete URL: %s%s%s\n", colorCyan, deviceResp.VerificationURIComplete, colorReset) |
| 91 | + _, _ = fmt.Println() |
| 92 | + } |
| 93 | + |
| 94 | + _, _ = fmt.Printf("Waiting for authorization (expires in %d seconds)...\n", deviceResp.ExpiresIn) |
| 95 | + _, _ = fmt.Printf("Polling every %d seconds...\n", deviceResp.Interval) |
| 96 | + _, _ = fmt.Println() |
| 97 | + |
| 98 | + // Step 3: Poll for token |
| 99 | + _, _ = fmt.Printf("%s=== Step 3: Token Polling ===%s\n", colorBlue, colorReset) |
| 100 | + tokenResp, err := pollForToken(ctx, config, deviceResp) |
| 101 | + if err != nil { |
| 102 | + log.Fatalf("Failed to get access token: %v", err) |
| 103 | + } |
| 104 | + |
| 105 | + _, _ = fmt.Printf("%s=== Authorization Successful! ===%s\n", colorGreen, colorReset) |
| 106 | + _, _ = fmt.Printf("%sAccess Token Response:%s\n", colorGreen, colorReset) |
| 107 | + prettyTokenJSON, _ := json.MarshalIndent(tokenResp, "", " ") |
| 108 | + _, _ = fmt.Printf("%s\n", prettyTokenJSON) |
| 109 | + _, _ = fmt.Println() |
| 110 | + |
| 111 | + // Step 4: Test the access token |
| 112 | + _, _ = fmt.Printf("%s=== Step 4: Testing Access Token ===%s\n", colorBlue, colorReset) |
| 113 | + if err := testAccessToken(ctx, config, tokenResp.AccessToken); err != nil { |
| 114 | + log.Printf("%sWarning: Failed to test access token: %v%s", colorYellow, err, colorReset) |
| 115 | + } else { |
| 116 | + _, _ = fmt.Printf("%sAccess token is valid and working!%s\n", colorGreen, colorReset) |
| 117 | + } |
| 118 | + |
| 119 | + _, _ = fmt.Println() |
| 120 | + _, _ = fmt.Printf("%sDevice authorization flow completed successfully!%s\n", colorGreen, colorReset) |
| 121 | + _, _ = fmt.Printf("You can now use the access token to make authenticated API requests.\n") |
| 122 | +} |
| 123 | + |
| 124 | +func requestDeviceCode(ctx context.Context, config *Config) (*DeviceCodeResponse, error) { |
| 125 | + // Use x/oauth2 clientcredentials config to structure the request |
| 126 | + // clientConfig := &clientcredentials.Config{ |
| 127 | + // ClientID: config.ClientID, |
| 128 | + // ClientSecret: config.ClientSecret, |
| 129 | + // TokenURL: config.BaseURL + "/oauth2/device", // Device code endpoint (RFC 8628) |
| 130 | + // } |
| 131 | + |
| 132 | + // Create form data for device code request |
| 133 | + data := url.Values{} |
| 134 | + data.Set("client_id", config.ClientID) |
| 135 | + |
| 136 | + // Optional: Add scope parameter |
| 137 | + // data.Set("scope", "openid profile") |
| 138 | + |
| 139 | + // Make the request to the device authorization endpoint |
| 140 | + req, err := http.NewRequestWithContext(ctx, "POST", config.BaseURL+"/oauth2/device", strings.NewReader(data.Encode())) |
| 141 | + if err != nil { |
| 142 | + return nil, xerrors.Errorf("creating request: %w", err) |
| 143 | + } |
| 144 | + |
| 145 | + // Set up basic auth with client credentials |
| 146 | + req.SetBasicAuth(config.ClientID, config.ClientSecret) |
| 147 | + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 148 | + |
| 149 | + client := &http.Client{Timeout: 30 * time.Second} |
| 150 | + resp, err := client.Do(req) |
| 151 | + if err != nil { |
| 152 | + return nil, xerrors.Errorf("making request: %w", err) |
| 153 | + } |
| 154 | + defer func() { _ = resp.Body.Close() }() |
| 155 | + |
| 156 | + if resp.StatusCode != http.StatusOK { |
| 157 | + var errResp ErrorResponse |
| 158 | + if err := json.NewDecoder(resp.Body).Decode(&errResp); err == nil { |
| 159 | + return nil, xerrors.Errorf("device code request failed: %s - %s", errResp.Error, errResp.ErrorDescription) |
| 160 | + } |
| 161 | + return nil, xerrors.Errorf("device code request failed with status %d", resp.StatusCode) |
| 162 | + } |
| 163 | + |
| 164 | + var deviceResp DeviceCodeResponse |
| 165 | + if err := json.NewDecoder(resp.Body).Decode(&deviceResp); err != nil { |
| 166 | + return nil, xerrors.Errorf("decoding response: %w", err) |
| 167 | + } |
| 168 | + |
| 169 | + return &deviceResp, nil |
| 170 | +} |
| 171 | + |
| 172 | +func pollForToken(ctx context.Context, config *Config, deviceResp *DeviceCodeResponse) (*TokenResponse, error) { |
| 173 | + // Use x/oauth2 config for token exchange |
| 174 | + oauth2Config := &oauth2.Config{ |
| 175 | + ClientID: config.ClientID, |
| 176 | + ClientSecret: config.ClientSecret, |
| 177 | + Endpoint: oauth2.Endpoint{ |
| 178 | + TokenURL: config.BaseURL + "/oauth2/token", |
| 179 | + }, |
| 180 | + } |
| 181 | + |
| 182 | + interval := time.Duration(deviceResp.Interval) * time.Second |
| 183 | + if interval < 5*time.Second { |
| 184 | + interval = 5 * time.Second // Minimum polling interval |
| 185 | + } |
| 186 | + |
| 187 | + deadline := time.Now().Add(time.Duration(deviceResp.ExpiresIn) * time.Second) |
| 188 | + ticker := time.NewTicker(interval) |
| 189 | + defer ticker.Stop() |
| 190 | + |
| 191 | + for { |
| 192 | + select { |
| 193 | + case <-ctx.Done(): |
| 194 | + return nil, ctx.Err() |
| 195 | + case <-ticker.C: |
| 196 | + if time.Now().After(deadline) { |
| 197 | + return nil, xerrors.New("device code expired") |
| 198 | + } |
| 199 | + |
| 200 | + _, _ = fmt.Printf("Polling for token...\n") |
| 201 | + |
| 202 | + // Create token exchange request using device_code grant |
| 203 | + data := url.Values{} |
| 204 | + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code") |
| 205 | + data.Set("device_code", deviceResp.DeviceCode) |
| 206 | + data.Set("client_id", config.ClientID) |
| 207 | + |
| 208 | + req, err := http.NewRequestWithContext(ctx, "POST", oauth2Config.Endpoint.TokenURL, strings.NewReader(data.Encode())) |
| 209 | + if err != nil { |
| 210 | + return nil, xerrors.Errorf("creating token request: %w", err) |
| 211 | + } |
| 212 | + |
| 213 | + req.SetBasicAuth(config.ClientID, config.ClientSecret) |
| 214 | + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 215 | + |
| 216 | + client := &http.Client{Timeout: 30 * time.Second} |
| 217 | + resp, err := client.Do(req) |
| 218 | + if err != nil { |
| 219 | + _, _ = fmt.Printf("Request error: %v\n", err) |
| 220 | + continue |
| 221 | + } |
| 222 | + |
| 223 | + var result map[string]interface{} |
| 224 | + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { |
| 225 | + _ = resp.Body.Close() |
| 226 | + _, _ = fmt.Printf("Decode error: %v\n", err) |
| 227 | + continue |
| 228 | + } |
| 229 | + _ = resp.Body.Close() |
| 230 | + |
| 231 | + if errorCode, ok := result["error"].(string); ok { |
| 232 | + switch errorCode { |
| 233 | + case "authorization_pending": |
| 234 | + _, _ = fmt.Printf("Authorization pending... continuing to poll\n") |
| 235 | + continue |
| 236 | + case "slow_down": |
| 237 | + _, _ = fmt.Printf("Slow down request - increasing polling interval by 5 seconds\n") |
| 238 | + interval += 5 * time.Second |
| 239 | + ticker.Reset(interval) |
| 240 | + continue |
| 241 | + case "access_denied": |
| 242 | + return nil, xerrors.New("access denied by user") |
| 243 | + case "expired_token": |
| 244 | + return nil, xerrors.New("device code expired") |
| 245 | + default: |
| 246 | + desc := "" |
| 247 | + if errorDesc, ok := result["error_description"].(string); ok { |
| 248 | + desc = " - " + errorDesc |
| 249 | + } |
| 250 | + return nil, xerrors.Errorf("token error: %s%s", errorCode, desc) |
| 251 | + } |
| 252 | + } |
| 253 | + |
| 254 | + // Success case - convert to TokenResponse |
| 255 | + var tokenResp TokenResponse |
| 256 | + if accessToken, ok := result["access_token"].(string); ok { |
| 257 | + tokenResp.AccessToken = accessToken |
| 258 | + } |
| 259 | + if tokenType, ok := result["token_type"].(string); ok { |
| 260 | + tokenResp.TokenType = tokenType |
| 261 | + } |
| 262 | + if expiresIn, ok := result["expires_in"].(float64); ok { |
| 263 | + tokenResp.ExpiresIn = int(expiresIn) |
| 264 | + } |
| 265 | + if refreshToken, ok := result["refresh_token"].(string); ok { |
| 266 | + tokenResp.RefreshToken = refreshToken |
| 267 | + } |
| 268 | + if scope, ok := result["scope"].(string); ok { |
| 269 | + tokenResp.Scope = scope |
| 270 | + } |
| 271 | + |
| 272 | + if tokenResp.AccessToken == "" { |
| 273 | + return nil, xerrors.New("no access token in response") |
| 274 | + } |
| 275 | + |
| 276 | + return &tokenResp, nil |
| 277 | + } |
| 278 | + } |
| 279 | +} |
| 280 | + |
| 281 | +func testAccessToken(ctx context.Context, config *Config, accessToken string) error { |
| 282 | + req, err := http.NewRequestWithContext(ctx, "GET", config.BaseURL+"/api/v2/users/me", nil) |
| 283 | + if err != nil { |
| 284 | + return xerrors.Errorf("creating request: %w", err) |
| 285 | + } |
| 286 | + |
| 287 | + req.Header.Set("Coder-Session-Token", accessToken) |
| 288 | + |
| 289 | + client := &http.Client{Timeout: 10 * time.Second} |
| 290 | + resp, err := client.Do(req) |
| 291 | + if err != nil { |
| 292 | + return xerrors.Errorf("making request: %w", err) |
| 293 | + } |
| 294 | + defer func() { _ = resp.Body.Close() }() |
| 295 | + |
| 296 | + if resp.StatusCode != http.StatusOK { |
| 297 | + return xerrors.Errorf("API request failed with status %d", resp.StatusCode) |
| 298 | + } |
| 299 | + |
| 300 | + var userInfo map[string]interface{} |
| 301 | + if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { |
| 302 | + return xerrors.Errorf("decoding response: %w", err) |
| 303 | + } |
| 304 | + |
| 305 | + _, _ = fmt.Printf("%sAPI Test Response:%s\n", colorGreen, colorReset) |
| 306 | + prettyJSON, _ := json.MarshalIndent(userInfo, "", " ") |
| 307 | + _, _ = fmt.Printf("%s\n", prettyJSON) |
| 308 | + |
| 309 | + return nil |
| 310 | +} |
| 311 | + |
| 312 | +func getEnvOrDefault(key, defaultValue string) string { |
| 313 | + if value := os.Getenv(key); value != "" { |
| 314 | + return value |
| 315 | + } |
| 316 | + return defaultValue |
| 317 | +} |
0 commit comments