File: //snap/google-cloud-cli/396/platform/bq/auth/gcloud_credential_loader.py
#!/usr/bin/env python
"""Utilities to load Google Auth credentials from gcloud."""
import datetime
import logging
import subprocess
from typing import Iterator, List, Optional
from google.oauth2 import credentials as google_oauth2
import bq_auth_flags
import bq_flags
import bq_utils
from auth import utils as bq_auth_utils
from gcloud_wrapper import gcloud_runner
from utils import bq_error
from utils import bq_gcloud_utils
# Substring looked for in exception messages to detect that the `gcloud`
# executable could not be found.
ERROR_TEXT_PRODUCED_IF_GCLOUD_NOT_FOUND = "No such file or directory: 'gcloud'"
# OAuth scopes passed via `--scopes` when printing a service-account access
# token with Google Drive access enabled.
_GDRIVE_SCOPE, _GCP_SCOPE = (
    'https://www.googleapis.com/auth/drive',
    'https://www.googleapis.com/auth/cloud-platform',
)
def LoadCredential() -> google_oauth2.Credentials:
  """Builds OAuth2 credentials for the active gcloud account.

  The active account is read from the gcloud config. An access token is
  fetched from gcloud. For non-service accounts a refresh token is also
  fetched. For service accounts a refresh handler that re-invokes gcloud is
  installed instead.

  Returns:
    A `google.oauth2.credentials.Credentials` for the active account.

  Raises:
    bq_error.BigqueryError: if fetching a token from gcloud fails.
  """
  config = bq_gcloud_utils.load_config()
  account = config.get('core', {}).get('account', '')
  logging.info('Loading auth credentials from gcloud for account: %s', account)

  service_account = bq_utils.IsServiceAccount(account)
  token = _GetAccessTokenAndPrintOutput(service_account)

  if service_account:
    # Service accounts refresh through the handler rather than a refresh token.
    refresh_token = None
    refresh_handler = _ServiceAccountRefreshHandler
  else:
    refresh_token = _GetRefreshTokenAndPrintOutput()
    refresh_handler = None

  fallback_project = _GetFallbackQuotaProjectId(
      is_service_account=service_account,
      has_refresh_token=refresh_token is not None,
  )
  return google_oauth2.Credentials(
      account=account,
      token=token,
      refresh_token=refresh_token,
      refresh_handler=refresh_handler,
      client_id=bq_auth_utils.get_client_id(),
      client_secret=bq_auth_utils.get_client_secret(),
      token_uri=bq_auth_utils.get_token_uri(),
      quota_project_id=bq_utils.GetResolvedQuotaProjectID(
          bq_auth_flags.QUOTA_PROJECT_ID.value, fallback_project
      ),
  )
def _GetScopes() -> List[str]:
  """Returns the OAuth scopes to request explicitly from gcloud.

  When Google Drive access is enabled, the Drive and Cloud Platform scopes are
  returned. Otherwise the result is an empty list.
  """
  if not bq_flags.ENABLE_GDRIVE.value:
    return []
  return [_GDRIVE_SCOPE, _GCP_SCOPE]
def _GetAccessTokenAndPrintOutput(
    is_service_account: bool, scopes: Optional[List[str]] = None
) -> Optional[str]:
  """Fetches an access token via `gcloud auth print-access-token`.

  Args:
    is_service_account: whether the active account is a service account. Only
      service accounts pass `--scopes` to gcloud, and only when the scope list
      is non-empty.
    scopes: scopes to request. When None, `_GetScopes()` decides.

  Returns:
    The access token, or None if gcloud printed no token-like line.
  """
  if scopes is None:
    scopes = _GetScopes()
  command = ['auth', 'print-access-token']
  if is_service_account and scopes:
    command += ['--scopes', ','.join(scopes)]
  return _GetTokenFromGcloudAndPrintOtherOutput(command)
def _GetRefreshTokenAndPrintOutput() -> Optional[str]:
  """Fetches a refresh token via `gcloud auth print-refresh-token`.

  Returns:
    The refresh token. Returns None when gcloud prints no token-like line, or
    when it reports that the credential type does not support refresh tokens.
  """
  command = ['auth', 'print-refresh-token']
  return _GetTokenFromGcloudAndPrintOtherOutput(command)
def _GetTokenFromGcloudAndPrintOtherOutput(
    cmd: List[str],
    stderr: Optional[int] = subprocess.STDOUT,
) -> Optional[str]:
  """Returns a token or prints other messages from the given gcloud command.

  The first non-empty output line containing no spaces is treated as the
  token. Every output line before it is printed, and lines after it are not
  read.

  Args:
    cmd: gcloud arguments, e.g. ['auth', 'print-access-token'].
    stderr: forwarded unchanged to `_RunGcloudCommand`.

  Returns:
    The token. Returns None if no token-like line was produced, or if gcloud
    reports that the credential type does not support refresh tokens.

  Raises:
    bq_error.BigqueryError: if the gcloud command fails or `gcloud` cannot be
      found. The underlying exception is chained as `__cause__`.
  """
  try:
    token = None
    for output in _RunGcloudCommand(cmd, stderr):
      if output and ' ' not in output:
        # Token is a non-empty string of non-space characters.
        token = output
        break
      else:
        print(output)
    return token
  except bq_error.BigqueryError as e:
    # Flatten the (possibly multi-line) error text before substring matching.
    single_line_error_msg = str(e).replace('\n', '')
    if 'security key' in single_line_error_msg:
      raise bq_error.BigqueryError(
          'Access token has expired. Did you touch the security key within the'
          ' timeout window?\n'
          + _GetReauthMessage()
      ) from e
    elif 'Refresh token has expired' in single_line_error_msg:
      raise bq_error.BigqueryError(
          'Refresh token has expired. ' + _GetReauthMessage()
      ) from e
    elif 'do not support refresh tokens' in single_line_error_msg:
      # It's expected that certain credential types don't support refresh token.
      return None
    else:
      raise bq_error.BigqueryError(
          'Error retrieving auth credentials from gcloud: %s'
          % _UpdateReauthMessage(str(e))
      ) from e
  except Exception as e:  # pylint: disable=broad-exception-caught
    single_line_error_msg = str(e).replace('\n', '')
    if ERROR_TEXT_PRODUCED_IF_GCLOUD_NOT_FOUND in single_line_error_msg:
      raise bq_error.BigqueryError(
          "'gcloud' not found but is required for authentication. To install,"
          ' follow these instructions:'
          ' https://cloud.google.com/sdk/docs/install'
      ) from e
    raise bq_error.BigqueryError(
        'Error retrieving auth credentials from gcloud: %s' % str(e)
    ) from e
def _RunGcloudCommand(
    cmd: List[str], stderr: Optional[int] = subprocess.STDOUT
) -> Iterator[str]:
  """Runs the given gcloud command and yields its stripped output lines.

  When a line starting with 'ERROR:' appears, that line and every line after
  it are collected instead of yielded. The child's stdout is always closed
  and the process is always waited on, even if the caller stops iterating
  early, for example after reading a token.

  Args:
    cmd: gcloud arguments passed to `gcloud_runner.run_gcloud_command`.
    stderr: forwarded to `gcloud_runner.run_gcloud_command`.

  Yields:
    Each whitespace-stripped output line that precedes the first 'ERROR:'
    line.

  Raises:
    bq_error.BigqueryError: if the output is fully consumed and the command
      exits with a non-zero status. The message is the collected error lines.
  """
  proc = gcloud_runner.run_gcloud_command(cmd, stderr=stderr)
  error_msgs = []
  try:
    if proc.stdout:
      # NOTE(review): the '' sentinel assumes stdout is a text-mode stream;
      # confirm against gcloud_runner.
      for stdout_line in iter(proc.stdout.readline, ''):
        line = str(stdout_line).strip()
        if line.startswith('ERROR:') or error_msgs:
          error_msgs.append(line)
        else:
          yield line
  finally:
    # Also runs when the consumer abandons the generator (GeneratorExit at the
    # yield above). Without this, the pipe would stay open and the child
    # would never be reaped.
    if proc.stdout:
      proc.stdout.close()
    return_code = proc.wait()
  # Only reached when the output was read to the end.
  if return_code:
    raise bq_error.BigqueryError('\n'.join(error_msgs))
def _GetReauthMessage() -> str:
  """Returns a message telling the user which gcloud command re-authenticates.

  The suggested command gains `--enable-gdrive-access` when the ENABLE_GDRIVE
  flag is set.
  """
  gcloud_command = '$ gcloud auth login'
  if bq_flags.ENABLE_GDRIVE.value:
    gcloud_command += ' --enable-gdrive-access'
  return 'To re-authenticate, run:\n\n%s' % gcloud_command
def _UpdateReauthMessage(message: str) -> str:
  """Appends `--enable-gdrive-access` to `$ gcloud auth login` hints.

  Every occurrence is rewritten, but only when the ENABLE_GDRIVE flag is set.
  Otherwise, or when the message has no such hint, it is returned unchanged.
  """
  has_login_hint = '$ gcloud auth login' in message
  if has_login_hint and bq_flags.ENABLE_GDRIVE.value:
    return message.replace(
        '$ gcloud auth login',
        '$ gcloud auth login --enable-gdrive-access',
    )
  return message
def _GetFallbackQuotaProjectId(
    is_service_account: bool, has_refresh_token: bool
) -> Optional[str]:
  """Returns the quota project to fall back to, or None.

  Only credentials that are not service accounts fall back to the resource
  project ID (the PROJECT_ID flag). "Not a service account" means both that
  the account name is not a service account and that a non-empty refresh
  token was obtained. Service-account credentials get no fallback.
  """
  use_resource_project = not is_service_account and has_refresh_token
  return bq_flags.PROJECT_ID.value if use_resource_project else None
def _ServiceAccountRefreshHandler(request, scopes):
  """Mints a new service-account access token by calling gcloud.

  LoadCredential installs this as the `refresh_handler` for service accounts.

  Args:
    request: unused.
    scopes: scopes forwarded to `_GetAccessTokenAndPrintOutput`.

  Returns:
    A `(token, expiry)` tuple. `expiry` is a naive datetime in UTC, 55 minutes
    from now.
  """
  del request  # Unused.
  token = _GetAccessTokenAndPrintOutput(is_service_account=True, scopes=scopes)
  # Per https://cloud.google.com/docs/authentication/token-types#at-lifetime
  # and https://cloud.google.com/sdk/gcloud/reference/auth/print-access-token,
  # tokens from `gcloud auth print-access-token` last 1 hour. Expire them
  # 5 minutes early to be safe.
  now_utc = datetime.datetime.now(datetime.timezone.utc)
  expiry = (now_utc + datetime.timedelta(minutes=55)).replace(tzinfo=None)
  return token, expiry