package oauth2

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"golang.org/x/oauth2/internal"
)

// Error codes the token endpoint may return while a client polls for a
// device access token.
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
const (
	errAuthorizationPending = "authorization_pending"
	errSlowDown             = "slow_down"
	errAccessDenied         = "access_denied"
	errExpiredToken         = "expired_token"
)

// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
type DeviceAuthResponse struct {
	// DeviceCode is the device verification code. DeviceAccessToken sends it
	// as the device_code parameter when polling the token endpoint.
	DeviceCode string `json:"device_code"`
	// UserCode is the code the user should enter at the verification uri
	UserCode string `json:"user_code"`
	// VerificationURI is where user should enter the user code
	VerificationURI string `json:"verification_uri"`
	// VerificationURIComplete (if populated) includes the user code in the
	// verification URI. This is typically shown to the user in non-textual
	// form, such as a QR code.
	VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
	// Expiry is when the device code and user code expire. On the wire it is
	// the expires_in lifetime in seconds; MarshalJSON and UnmarshalJSON
	// convert between the two representations.
	Expiry time.Time `json:"expires_in,omitempty"`
	// Interval is the minimum number of seconds DeviceAccessToken waits
	// between polling requests. Zero selects the RFC 8628 default of 5.
	Interval int64 `json:"interval,omitempty"`
}

// MarshalJSON encodes d, replacing the absolute Expiry with an expires_in
// lifetime in whole seconds measured from now. expires_in is omitted when it
// computes to zero, which includes a zero Expiry.
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
	// Alias has the same fields but none of the methods, so json.Marshal on
	// it does not recurse back into this method.
	type Alias DeviceAuthResponse
	var expiresIn int64
	if !d.Expiry.IsZero() {
		expiresIn = int64(time.Until(d.Expiry).Seconds())
	}
	// The depth-0 ExpiresIn field shadows the embedded Expiry field, which
	// carries the same "expires_in" JSON name.
	return json.Marshal(&struct {
		ExpiresIn int64 `json:"expires_in,omitempty"`
		*Alias
	}{
		ExpiresIn: expiresIn,
		Alias:     (*Alias)(&d),
	})
}

// UnmarshalJSON decodes a Device Authorization Response into d. A non-zero
// expires_in lifetime (seconds) is converted into an absolute UTC Expiry
// relative to the current time. The non-standard verification_url field is
// used as a fallback when verification_uri is absent. A JSON null is a no-op.
func (d *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
	type Alias DeviceAuthResponse
	aux := &struct {
		ExpiresIn int64 `json:"expires_in"`
		// workaround misspelling of verification_uri
		VerificationURL string `json:"verification_url"`
		*Alias
	}{
		Alias: (*Alias)(d),
	}
	// Decode into aux itself rather than &aux: via &aux a JSON null would
	// reset the aux pointer to nil and the field reads below would panic.
	if err := json.Unmarshal(data, aux); err != nil {
		return err
	}
	if aux.ExpiresIn != 0 {
		d.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
	}
	if d.VerificationURI == "" {
		d.VerificationURI = aux.VerificationURL
	}
	return nil
}

// DeviceAuth returns a device auth struct which contains a device code
// and authorization information provided for users to enter on another device.
//
// The request sends the configured ClientID and, when set, the space-joined
// Scopes to Endpoint.DeviceAuthURL; opts can set further request parameters.
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
	v := url.Values{
		"client_id": {c.ClientID},
	}
	if len(c.Scopes) > 0 {
		v.Set("scope", strings.Join(c.Scopes, " "))
	}
	for _, opt := range opts {
		opt.setValue(v)
	}
	return retrieveDeviceAuth(ctx, c, v)
}

// retrieveDeviceAuth POSTs the form-encoded values v to c.Endpoint.DeviceAuthURL
// and decodes the Device Authorization Response.
//
// It returns an error if DeviceAuthURL is empty. A non-2xx status is reported
// as a *RetrieveError carrying the response and its (size-capped) body.
// Expiry is pulled back by the request's round-trip time so it does not
// overstate the remaining lifetime.
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
	if c.Endpoint.DeviceAuthURL == "" {
		return nil, errors.New("endpoint missing DeviceAuthURL")
	}

	// Bind the request to ctx so cancellation and deadlines abort the round trip.
	req, err := http.NewRequestWithContext(ctx, "POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	t := time.Now()
	r, err := internal.ContextClient(ctx).Do(req)
	if err != nil {
		return nil, err
	}
	defer r.Body.Close()

	// Read at most 1 MiB so a misbehaving server cannot exhaust memory.
	body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
	if err != nil {
		return nil, fmt.Errorf("oauth2: cannot auth device: %w", err)
	}
	if code := r.StatusCode; code < 200 || code > 299 {
		return nil, &RetrieveError{
			Response: r,
			Body:     body,
		}
	}

	// Decode into da rather than &da: via &da a JSON null body would set da
	// to nil and the Expiry adjustment below would panic.
	da := &DeviceAuthResponse{}
	if err := json.Unmarshal(body, da); err != nil {
		return nil, fmt.Errorf("unmarshal %w", err)
	}

	if !da.Expiry.IsZero() {
		// Make a small adjustment to account for time taken by the request
		da.Expiry = da.Expiry.Add(-time.Since(t))
	}

	return da, nil
}

// DeviceAccessToken polls the server to exchange a device code for a token.
func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) { if !da.Expiry.IsZero() { var cancel context.CancelFunc ctx, cancel = context.WithDeadline(ctx, da.Expiry) defer cancel() } // https://datatracker.ietf.org/doc/html/rfc8628#section-3.4 v := url.Values{ "client_id": {c.ClientID}, "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, "device_code": {da.DeviceCode}, } if len(c.Scopes) > 0 { v.Set("scope", strings.Join(c.Scopes, " ")) } for _, opt := range opts { opt.setValue(v) } // "If no value is provided, clients MUST use 5 as the default." // https://datatracker.ietf.org/doc/html/rfc8628#section-3.2 interval := da.Interval if interval == 0 { interval = 5 } ticker := time.NewTicker(time.Duration(interval) * time.Second) defer ticker.Stop() for { select { case <-ctx.Done(): return nil, ctx.Err() case <-ticker.C: tok, err := retrieveToken(ctx, c, v) if err == nil { return tok, nil } e, ok := err.(*RetrieveError) if !ok { return nil, err } switch e.ErrorCode { case errSlowDown: // https://datatracker.ietf.org/doc/html/rfc8628#section-3.5 // "the interval MUST be increased by 5 seconds for this and all subsequent requests" interval += 5 ticker.Reset(time.Duration(interval) * time.Second) case errAuthorizationPending: // Do nothing. case errAccessDenied, errExpiredToken: fallthrough default: return tok, err } } } }