#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions :
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

from datetime import datetime, timedelta
import math
import re
import json
import time
import uuid

try:
    from urllib.parse import urlencode, urlparse
except ImportError:
    from urllib import urlencode # pylint: disable=no-name-in-module
    from urlparse import urlparse # pylint: disable=import-error,ungrouped-imports

import requests

from . import log
from . import util
from .constants import OAuth2, TokenResponseFields, IdTokenFields
from .adal_error import AdalError

# Wire-format token response parameter name -> TokenResponseFields name.
TOKEN_RESPONSE_MAP = {
    OAuth2.ResponseParameters.TOKEN_TYPE : TokenResponseFields.TOKEN_TYPE,
    OAuth2.ResponseParameters.ACCESS_TOKEN : TokenResponseFields.ACCESS_TOKEN,
    OAuth2.ResponseParameters.REFRESH_TOKEN : TokenResponseFields.REFRESH_TOKEN,
    OAuth2.ResponseParameters.CREATED_ON : TokenResponseFields.CREATED_ON,
    OAuth2.ResponseParameters.EXPIRES_ON : TokenResponseFields.EXPIRES_ON,
    OAuth2.ResponseParameters.EXPIRES_IN : TokenResponseFields.EXPIRES_IN,
    OAuth2.ResponseParameters.RESOURCE : TokenResponseFields.RESOURCE,
    OAuth2.ResponseParameters.ERROR : TokenResponseFields.ERROR,
    OAuth2.ResponseParameters.ERROR_DESCRIPTION : TokenResponseFields.ERROR_DESCRIPTION,
}

# Request options passed to util.create_request_options for every POST here.
_REQ_OPTION = {'headers' : {'content-type': 'application/x-www-form-urlencoded'}}
_ERROR_TEMPLATE = u"{} request returned http error: {}"

# Splits a compact JWT into header, payload (must be non-empty) and signature.
_JWT_PARTS_REGEX = re.compile(r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$")


def map_fields(in_obj, map_to):
    """Return a new dict with the keys of ``in_obj`` renamed through ``map_to``.

    Keys of ``in_obj`` that do not appear in ``map_to`` are dropped.
    """
    return {map_to[k]: v for k, v in in_obj.items() if k in map_to}


def _get_user_id(id_token):
    """Choose a user identifier from decoded id_token claims.

    Preference order is 'upn', then 'email', then 'sub'. If none of them is
    present and non-empty, a random UUID string is used. The identifier is
    flagged as displayable only when it came from 'upn' or 'email'.

    Returns:
        dict with IdTokenFields.USER_ID and, when displayable,
        IdTokenFields.IS_USER_ID_DISPLAYABLE set to True.
    """
    user_id = None
    is_displayable = False

    if id_token.get('upn'):
        user_id = id_token['upn']
        is_displayable = True
    elif id_token.get('email'):
        user_id = id_token['email']
        is_displayable = True
    elif id_token.get('sub'):
        user_id = id_token['sub']

    if not user_id:
        user_id = str(uuid.uuid4())

    user_id_vals = {IdTokenFields.USER_ID: user_id}
    if is_displayable:
        user_id_vals[IdTokenFields.IS_USER_ID_DISPLAYABLE] = True

    return user_id_vals


def _extract_token_values(id_token):
    """Map id_token claims via OAuth2.IdTokenMap and add the user-id fields."""
    extracted_values = map_fields(id_token, OAuth2.IdTokenMap)
    extracted_values.update(_get_user_id(id_token))
    return extracted_values


class OAuth2Client(object):
    """Sends OAuth2 token and device-code requests to an authority's endpoints."""

    def __init__(self, call_context, authority):
        """
        Args:
            call_context: dict. 'log_context' is required. This class also reads
                the optional keys 'api_version', 'verify_ssl', 'proxies' and
                'timeout'. The dict is also handed to util.create_request_options
                through ``self``.
            authority: object exposing ``token_endpoint`` and
                ``device_code_endpoint``.
        """
        self._token_endpoint = authority.token_endpoint
        self._device_code_endpoint = authority.device_code_endpoint
        self._log = log.Logger("OAuth2Client", call_context['log_context'])
        self._call_context = call_context
        # Set by cancel_polling_request(); checked before each poll in
        # get_token_with_polling().
        self._cancel_polling_request = False

    def _create_token_url(self):
        """Return the parsed token endpoint URL.

        The AAD api-version parameter is added only when the call context
        supplies 'api_version'.
        """
        parameters = {}
        if self._call_context.get('api_version'):
            parameters[OAuth2.Parameters.AAD_API_VERSION] = self._call_context[
                'api_version']

        return urlparse('{}?{}'.format(self._token_endpoint, urlencode(parameters)))

    def _create_device_code_url(self):
        """Return the parsed device-code endpoint URL, pinned to api-version 1.0."""
        parameters = {}
        parameters[OAuth2.Parameters.AAD_API_VERSION] = '1.0'
        return urlparse('{}?{}'.format(self._device_code_endpoint,
                                       urlencode(parameters)))

    def _parse_optional_ints(self, obj, keys):
        """Convert ``obj[key]`` to int in place for each key that is present.

        Missing keys are skipped. A value that cannot be converted is logged,
        and the original ValueError or TypeError is re-raised.
        """
        for key in keys:
            try:
                obj[key] = int(obj[key])
            except (ValueError, TypeError):
                # TypeError covers an explicit null/None value in the response.
                self._log.exception("%(key)s could not be parsed as an int", {"key": key})
                raise
            except KeyError:
                # if the key isn't present we can just continue
                pass

    def _parse_id_token(self, encoded_token):
        """Decode an id_token and extract its user-related values.

        Returns:
            dict from _extract_token_values. Returns None if the token cannot be
            split into JWT parts, or if util.base64_urlsafe_decode returns a
            falsy result for the payload.

        Raises:
            ValueError: the decoded payload is not valid UTF-8 JSON. The error is
                logged first.
        """
        cracked_token = self._open_jwt(encoded_token)
        if not cracked_token:
            return

        try:
            b64_id_token = cracked_token['JWSPayload']
            b64_decoded = util.base64_urlsafe_decode(b64_id_token)
            if not b64_decoded:
                self._log.warn('The returned id_token could not be base64 url safe decoded.')
                return

            id_token = json.loads(b64_decoded.decode('utf-8'))
        except ValueError:
            self._log.exception(
                "The returned id_token could not be decoded: %(id_token)s",
                {"id_token": encoded_token})
            raise

        return _extract_token_values(id_token)

    def _open_jwt(self, jwt_token):
        """Split a compact JWT string into its three dot-separated parts.

        Returns:
            dict with 'header', 'JWSPayload' and 'JWSSig', or {} (with a warning
            logged) when the string does not match the expected shape.
        """
        matches = _JWT_PARTS_REGEX.search(jwt_token)
        if not matches:
            self._log.warn('The token was not parsable.')
            return {}

        return {
            'header': matches.group(1),
            'JWSPayload': matches.group(2),
            'JWSSig': matches.group(3)
        }

    def _validate_token_response(self, body):
        """Parse and validate a token endpoint response body.

        The integer fields are normalized. 'expires_on' is recomputed as the
        string form of local ``datetime.now()`` plus 'expires_in' when
        'expires_in' is present. 'created_on' is converted from a timestamp to
        a datetime string.

        Args:
            body: raw JSON text returned by the token endpoint.

        Returns:
            dict keyed by TokenResponseFields (via TOKEN_RESPONSE_MAP). The
            id_token user fields are merged in when an id_token is present and
            decodable.

        Raises:
            ValueError: body is not JSON, an integer field is unparseable, or
                the id_token payload is not JSON.
            TypeError: an integer field is null.
            AdalError: token_type or access_token is missing.
        """
        try:
            wire_response = json.loads(body)
        except ValueError:
            self._log.exception(
                'The token response from the server is unparseable as JSON: %(token_response)s',
                {"token_response": body})
            raise

        int_keys = [
            OAuth2.ResponseParameters.EXPIRES_ON,
            OAuth2.ResponseParameters.EXPIRES_IN,
            OAuth2.ResponseParameters.CREATED_ON
        ]

        self._parse_optional_ints(wire_response, int_keys)

        expires_in = wire_response.get(OAuth2.ResponseParameters.EXPIRES_IN)
        if expires_in:
            now = datetime.now()
            soon = timedelta(seconds=expires_in)
            wire_response[OAuth2.ResponseParameters.EXPIRES_ON] = str(now + soon)

        created_on = wire_response.get(OAuth2.ResponseParameters.CREATED_ON)
        if created_on:
            temp_date = datetime.fromtimestamp(created_on)
            wire_response[OAuth2.ResponseParameters.CREATED_ON] = str(temp_date)

        if not wire_response.get(OAuth2.ResponseParameters.TOKEN_TYPE):
            raise AdalError('wire_response is missing token_type', wire_response)

        if not wire_response.get(OAuth2.ResponseParameters.ACCESS_TOKEN):
            raise AdalError('wire_response is missing access_token', wire_response)

        token_response = map_fields(wire_response, TOKEN_RESPONSE_MAP)

        if wire_response.get(OAuth2.ResponseParameters.ID_TOKEN):
            id_token = self._parse_id_token(wire_response[OAuth2.ResponseParameters.ID_TOKEN])
            if id_token:
                token_response.update(id_token)

        return token_response

    def _validate_device_code_response(self, body):
        """Parse and validate a device-code endpoint response body.

        Returns:
            The parsed dict with 'expires_in' and 'interval' converted to int
            when present.

        Raises:
            ValueError: body is not JSON, or an integer field is unparseable.
            AdalError: expires_in, device_code or user_code is missing.
        """
        try:
            wire_response = json.loads(body)
        except ValueError:
            # Log the offending body, consistent with _validate_token_response.
            self._log.exception(
                'The device code response returned from the server is unparseable as JSON: %(response)s',
                {"response": body})
            raise

        int_keys = [
            OAuth2.DeviceCodeResponseParameters.EXPIRES_IN,
            OAuth2.DeviceCodeResponseParameters.INTERVAL
        ]
        self._parse_optional_ints(wire_response, int_keys)

        if not wire_response.get(OAuth2.DeviceCodeResponseParameters.EXPIRES_IN):
            raise AdalError('wire_response is missing expires_in', wire_response)

        if not wire_response.get(OAuth2.DeviceCodeResponseParameters.DEVICE_CODE):
            raise AdalError('wire_response is missing device_code', wire_response)

        if not wire_response.get(OAuth2.DeviceCodeResponseParameters.USER_CODE):
            raise AdalError('wire_response is missing user_code', wire_response)

        # Skip the field-name mapping: names from the wire are already python style.
        return wire_response

    def _handle_get_token_response(self, body):
        """Validate a successful token response, logging any failure before re-raising."""
        try:
            return self._validate_token_response(body)
        except Exception:
            self._log.exception(
                "Error validating get token response: %(token_response)s",
                {"token_response": body})
            raise

    def _handle_get_device_code_response(self, body):
        """Validate a successful device-code response, logging any failure before re-raising."""
        try:
            return self._validate_device_code_response(body)
        except Exception:
            self._log.exception(
                "Error validating get user code response: %(token_response)s",
                {"token_response": body})
            raise

    def _post(self, url, data, headers, operation):
        """POST ``data`` to ``url`` using the transport options from the call context.

        Applies 'verify_ssl', 'proxies' and 'timeout' from the call context
        (each defaults to None) and logs the response correlation id. Any
        exception is logged as a failed ``operation`` and re-raised unchanged.
        """
        try:
            resp = requests.post(url,
                                 data=data,
                                 headers=headers,
                                 verify=self._call_context.get('verify_ssl', None),
                                 proxies=self._call_context.get('proxies', None),
                                 timeout=self._call_context.get('timeout', None))
            util.log_return_correlation_id(self._log, operation, resp)
        except Exception:
            self._log.exception("%(operation)s request failed", {"operation": operation})
            raise
        return resp

    @staticmethod
    def _raise_for_error_response(operation, resp):
        """Raise the appropriate error for a non-success HTTP response.

        Raises:
            requests.exceptions.HTTPError: for HTTP 429, via raise_for_status().
            AdalError: in all other cases. The message contains the status code
                and, if present, the response text. The parsed JSON body is
                attached when the body is valid JSON; otherwise "" is attached.
        """
        if resp.status_code == 429:
            resp.raise_for_status()  # Will raise requests.exceptions.HTTPError

        return_error_string = _ERROR_TEMPLATE.format(operation, resp.status_code)
        error_response = ""
        if resp.text:
            return_error_string = u"{} and server response: {}".format(return_error_string,
                                                                     resp.text)
            try:
                error_response = resp.json()
            except ValueError:
                # Non-JSON error body: keep "" as the attached payload.
                pass
        raise AdalError(return_error_string, error_response)

    def get_token(self, oauth_parameters):
        """POST ``oauth_parameters`` to the token endpoint and return the validated token.

        Returns:
            dict produced by _validate_token_response.

        Raises:
            requests.exceptions.HTTPError: on HTTP 429.
            AdalError: on any other non-success status or an invalid response.
        """
        token_url = self._create_token_url()
        url_encoded_token_request = urlencode(oauth_parameters)
        post_options = util.create_request_options(self, _REQ_OPTION)

        operation = "Get Token"
        resp = self._post(token_url.geturl(), url_encoded_token_request,
                          post_options['headers'], operation)

        if not util.is_http_success(resp.status_code):
            self._raise_for_error_response(operation, resp)
        return self._handle_get_token_response(resp.text)

    def get_user_code_info(self, oauth_parameters):
        """Request a device code and user code from the device-code endpoint.

        Returns:
            dict produced by _validate_device_code_response, plus
            'correlation_id' copied from the 'client-request-id' response
            header (None if that header is absent).

        Raises:
            requests.exceptions.HTTPError: on HTTP 429.
            AdalError: on any other non-success status or an invalid response.
        """
        device_code_url = self._create_device_code_url()
        url_encoded_code_request = urlencode(oauth_parameters)
        post_options = util.create_request_options(self, _REQ_OPTION)

        operation = "Get Device Code"
        resp = self._post(device_code_url.geturl(), url_encoded_code_request,
                          post_options['headers'], operation)

        if not util.is_http_success(resp.status_code):
            self._raise_for_error_response(operation, resp)

        user_code_info = self._handle_get_device_code_response(resp.text)
        user_code_info['correlation_id'] = resp.headers.get('client-request-id')
        return user_code_info

    def get_token_with_polling(self, oauth_parameters, refresh_internal, expires_in):
        """Poll the token endpoint until the device-code flow completes.

        Makes at most floor(expires_in / refresh_internal) attempts. After each
        'authorization_pending' error it sleeps ``refresh_internal`` seconds.

        Returns:
            dict produced by _validate_token_response once a success response
            arrives.

        Raises:
            AdalError: polling was cancelled, the server returned an error code
                other than 'authorization_pending', or all attempts were used
                ('Timeout from "get_token_with_polling"').
            requests.exceptions.HTTPError: on HTTP 429.
            ValueError: a non-success response body is not JSON.
        """
        token_url = self._create_token_url()
        url_encoded_code_request = urlencode(oauth_parameters)

        post_options = util.create_request_options(self, _REQ_OPTION)

        operation = "Get token with device code"
        # NOTE(review): refresh_internal == 0 raises ZeroDivisionError here.
        max_times_for_retry = math.floor(expires_in/refresh_internal)
        for _ in range(int(max_times_for_retry)):
            if self._cancel_polling_request:
                raise AdalError('Polling_Request_Cancelled')

            # _post applies the configured timeout, so one stalled poll cannot
            # block indefinitely.
            resp = self._post(token_url.geturl(), url_encoded_code_request,
                              post_options['headers'], operation)
            if resp.status_code == 429:
                resp.raise_for_status()  # Will raise requests.exceptions.HTTPError

            wire_response = {}
            if not util.is_http_success(resp.status_code):
                # on error, the body should be json already
                try:
                    wire_response = json.loads(resp.text)
                except ValueError:
                    self._log.exception(
                        "%(operation)s returned a non-JSON error response: %(response)s",
                        {"operation": operation, "response": resp.text})
                    raise

            error = wire_response.get(OAuth2.DeviceCodeResponseParameters.ERROR)
            if error == 'authorization_pending':
                time.sleep(refresh_internal)
                continue
            elif error:
                raise AdalError('Unexpected polling state {}'.format(error),
                                wire_response)
            else:
                try:
                    return self._validate_token_response(resp.text)
                except Exception:
                    self._log.exception(
                        u"Error validating get token response %(access_token)s",
                        {"access_token": resp.text})
                    raise

        raise AdalError('Timeout from "get_token_with_polling"')

    def cancel_polling_request(self):
        """Request that get_token_with_polling stop before its next attempt.

        A poll already in flight is not interrupted. The flag is checked at the
        top of each loop iteration, which then raises
        AdalError('Polling_Request_Cancelled').
        """
        self._cancel_polling_request = True