File: //snap/google-cloud-cli/394/lib/googlecloudsdk/api_lib/dataproc/util.py
# -*- coding: utf-8 -*- #
# Copyright 2015 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for the gcloud dataproc tool."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import base64
import hashlib
import json
import os
import subprocess
import tempfile
import time
import uuid
from apitools.base.py import encoding
from apitools.base.py import exceptions as apitools_exceptions
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.dataproc import exceptions
from googlecloudsdk.api_lib.dataproc import storage_helpers
from googlecloudsdk.calliope import arg_parsers
from googlecloudsdk.command_lib.export import util as export_util
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import requests
from googlecloudsdk.core.console import console_attr
from googlecloudsdk.core.console import console_io
from googlecloudsdk.core.console import progress_tracker
from googlecloudsdk.core.credentials import creds as c_creds
from googlecloudsdk.core.credentials import store as c_store
from googlecloudsdk.core.util import retry
import six
SCHEMA_DIR = os.path.join(os.path.dirname(__file__), 'schemas')
def FormatRpcError(error):
"""Returns a printable representation of a failed Google API's status.proto.
Args:
error: the failed Status to print.
Returns:
A ready-to-print string representation of the error.
"""
log.debug('Error:\n' + encoding.MessageToJson(error))
return error.message
def WaitForResourceDeletion(request_method,
resource_ref,
message,
timeout_s=60,
poll_period_s=5):
"""Poll Dataproc resource until it no longer exists."""
with progress_tracker.ProgressTracker(message, autotick=True):
start_time = time.time()
while timeout_s > (time.time() - start_time):
try:
request_method(resource_ref)
except apitools_exceptions.HttpNotFoundError:
# Object deleted
return
except apitools_exceptions.HttpError as error:
        log.debug('Get request for [%s] failed:\n%s', resource_ref, error)
# Do not retry on 4xx errors
if IsClientHttpException(error):
raise
time.sleep(poll_period_s)
raise exceptions.OperationTimeoutError(
'Deleting resource [{0}] timed out.'.format(resource_ref))
def GetUniqueId():
return uuid.uuid4().hex
class Bunch(object):
"""Class that converts a dictionary to javascript like object.
For example:
Bunch({'a': {'b': {'c': 0}}}).a.b.c == 0
"""
def __init__(self, dictionary):
for key, value in six.iteritems(dictionary):
if isinstance(value, dict):
value = Bunch(value)
self.__dict__[key] = value
def AddJvmDriverFlags(parser):
parser.add_argument(
'--jar',
dest='main_jar',
      help='The HCFS URI of the jar file containing the driver.')
parser.add_argument(
'--class',
dest='main_class',
      help=('The class containing the main method of the driver. Must be in a'
            ' provided jar or a jar that is already on the classpath.'))
def IsClientHttpException(http_exception):
"""Returns true if the http exception given is an HTTP 4xx error."""
return http_exception.status_code >= 400 and http_exception.status_code < 500
# TODO(b/36056506): Use api_lib.utils.waiter
def WaitForOperation(dataproc, operation, message, timeout_s, poll_period_s=5):
"""Poll dataproc Operation until its status is done or timeout reached.
Args:
dataproc: wrapper for Dataproc messages, resources, and client
operation: Operation, message of the operation to be polled.
message: str, message to display to user while polling.
timeout_s: number, seconds to poll with retries before timing out.
poll_period_s: number, delay in seconds between requests.
Returns:
Operation: the return value of the last successful operations.get
request.
Raises:
OperationError: if the operation times out or finishes with an error.
"""
request = dataproc.messages.DataprocProjectsRegionsOperationsGetRequest(
name=operation.name)
log.status.Print('Waiting on operation [{0}].'.format(operation.name))
start_time = time.time()
warnings_so_far = 0
is_tty = console_io.IsInteractive(error=True)
tracker_separator = '\n' if is_tty else ''
def _LogWarnings(warnings):
new_warnings = warnings[warnings_so_far:]
if new_warnings:
# Drop a line to print nicely with the progress tracker.
log.err.write(tracker_separator)
for warning in new_warnings:
log.warning(warning)
with progress_tracker.ProgressTracker(message, autotick=True):
while timeout_s > (time.time() - start_time):
try:
operation = dataproc.client.projects_regions_operations.Get(request)
metadata = ParseOperationJsonMetadata(
operation.metadata, dataproc.messages.ClusterOperationMetadata)
_LogWarnings(metadata.warnings)
warnings_so_far = len(metadata.warnings)
if operation.done:
break
except apitools_exceptions.HttpError as http_exception:
# Do not retry on 4xx errors.
if IsClientHttpException(http_exception):
raise
time.sleep(poll_period_s)
metadata = ParseOperationJsonMetadata(
operation.metadata, dataproc.messages.ClusterOperationMetadata)
_LogWarnings(metadata.warnings)
if not operation.done:
raise exceptions.OperationTimeoutError('Operation [{0}] timed out.'.format(
operation.name))
elif operation.error:
raise exceptions.OperationError('Operation [{0}] failed: {1}.'.format(
operation.name, FormatRpcError(operation.error)))
log.info('Operation [%s] finished after %.3f seconds', operation.name,
(time.time() - start_time))
return operation
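# A typical call to WaitForOperation, shown as a sketch only (the surrounding
# command code and variable names are illustrative, not part of this module):
#
#   operation = dataproc.client.projects_regions_clusters.Create(request)
#   operation = WaitForOperation(
#       dataproc, operation,
#       message='Waiting for cluster creation operation',
#       timeout_s=3600)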
def PrintWorkflowMetadata(metadata, status, operations, errors):
"""Print workflow and job status for the running workflow template.
This method will detect any changes of state in the latest metadata and print
all the new states in a workflow template.
For example:
Workflow template template-name RUNNING
Creating cluster: Operation ID create-id.
Job ID job-id-1 RUNNING
Job ID job-id-1 COMPLETED
Deleting cluster: Operation ID delete-id.
Workflow template template-name DONE
Args:
metadata: Dataproc WorkflowMetadata message object, contains the latest
states of a workflow template.
status: Dictionary, stores all jobs' status in the current workflow
template, as well as the status of the overarching workflow.
operations: Dictionary, stores cluster operation status for the workflow
template.
errors: Dictionary, stores errors from the current workflow template.
"""
# Key chosen to avoid collision with job ids, which are at least 3 characters.
template_key = 'wt'
if template_key not in status or metadata.state != status[template_key]:
if metadata.template is not None:
log.status.Print('WorkflowTemplate [{0}] {1}'.format(
metadata.template, metadata.state))
else:
# Workflows instantiated inline do not store an id in their metadata.
log.status.Print('WorkflowTemplate {0}'.format(metadata.state))
status[template_key] = metadata.state
if metadata.createCluster != operations['createCluster']:
if hasattr(metadata.createCluster,
'error') and metadata.createCluster.error is not None:
log.status.Print(metadata.createCluster.error)
elif hasattr(metadata.createCluster,
'done') and metadata.createCluster.done is not None:
log.status.Print('Created cluster: {0}.'.format(metadata.clusterName))
elif hasattr(
metadata.createCluster,
'operationId') and metadata.createCluster.operationId is not None:
log.status.Print('Creating cluster: Operation ID [{0}].'.format(
metadata.createCluster.operationId))
operations['createCluster'] = metadata.createCluster
if hasattr(metadata.graph, 'nodes'):
for node in metadata.graph.nodes:
if not node.jobId:
continue
if node.jobId not in status or status[node.jobId] != node.state:
log.status.Print('Job ID {0} {1}'.format(node.jobId, node.state))
status[node.jobId] = node.state
if node.error and (node.jobId not in errors or
errors[node.jobId] != node.error):
log.status.Print('Job ID {0} error: {1}'.format(node.jobId, node.error))
errors[node.jobId] = node.error
if metadata.deleteCluster != operations['deleteCluster']:
if hasattr(metadata.deleteCluster,
'error') and metadata.deleteCluster.error is not None:
log.status.Print(metadata.deleteCluster.error)
elif hasattr(metadata.deleteCluster,
'done') and metadata.deleteCluster.done is not None:
log.status.Print('Deleted cluster: {0}.'.format(metadata.clusterName))
elif hasattr(
metadata.deleteCluster,
'operationId') and metadata.deleteCluster.operationId is not None:
log.status.Print('Deleting cluster: Operation ID [{0}].'.format(
metadata.deleteCluster.operationId))
operations['deleteCluster'] = metadata.deleteCluster
# TODO(b/36056506): Use api_lib.utils.waiter
def WaitForWorkflowTemplateOperation(dataproc,
operation,
timeout_s=None,
poll_period_s=5):
"""Poll dataproc Operation until its status is done or timeout reached.
Args:
dataproc: wrapper for Dataproc messages, resources, and client
operation: Operation, message of the operation to be polled.
timeout_s: number, seconds to poll with retries before timing out.
poll_period_s: number, delay in seconds between requests.
Returns:
Operation: the return value of the last successful operations.get
request.
Raises:
OperationError: if the operation times out or finishes with an error.
"""
request = dataproc.messages.DataprocProjectsRegionsOperationsGetRequest(
name=operation.name)
log.status.Print('Waiting on operation [{0}].'.format(operation.name))
start_time = time.time()
operations = {'createCluster': None, 'deleteCluster': None}
status = {}
errors = {}
# If no timeout is specified, poll forever.
while timeout_s is None or timeout_s > (time.time() - start_time):
try:
operation = dataproc.client.projects_regions_operations.Get(request)
metadata = ParseOperationJsonMetadata(operation.metadata,
dataproc.messages.WorkflowMetadata)
PrintWorkflowMetadata(metadata, status, operations, errors)
if operation.done:
break
except apitools_exceptions.HttpError as http_exception:
# Do not retry on 4xx errors.
if IsClientHttpException(http_exception):
raise
time.sleep(poll_period_s)
metadata = ParseOperationJsonMetadata(operation.metadata,
dataproc.messages.WorkflowMetadata)
if not operation.done:
raise exceptions.OperationTimeoutError('Operation [{0}] timed out.'.format(
operation.name))
elif operation.error:
raise exceptions.OperationError('Operation [{0}] failed: {1}.'.format(
operation.name, FormatRpcError(operation.error)))
for op in ['createCluster', 'deleteCluster']:
if op in operations and operations[op] is not None and operations[op].error:
raise exceptions.OperationError('Operation [{0}] failed: {1}.'.format(
operations[op].operationId, operations[op].error))
log.info('Operation [%s] finished after %.3f seconds', operation.name,
(time.time() - start_time))
return operation
class NoOpProgressDisplay(object):
"""For use in place of a ProgressTracker in a 'with' block."""
def __enter__(self):
pass
def __exit__(self, *unused_args):
pass
def WaitForJobTermination(dataproc,
job,
job_ref,
message,
goal_state,
error_state=None,
stream_driver_log=False,
log_poll_period_s=1,
dataproc_poll_period_s=10,
timeout_s=None):
"""Poll dataproc Job until its status is terminal or timeout reached.
Args:
dataproc: wrapper for dataproc resources, client and messages
job: The job to wait to finish.
job_ref: Parsed dataproc.projects.regions.jobs resource containing a
projectId, region, and jobId.
message: str, message to display to user while polling.
    goal_state: JobStatus.StateValueValuesEnum, the state that defines success.
    error_state: JobStatus.StateValueValuesEnum, the state that defines failure.
stream_driver_log: bool, Whether to show the Job's driver's output.
log_poll_period_s: number, delay in seconds between checking on the log.
dataproc_poll_period_s: number, delay in seconds between requests to the
Dataproc API.
timeout_s: number, time out for job completion. None means no timeout.
Returns:
Job: the return value of the last successful jobs.get request.
Raises:
JobError: if the job finishes with an error.
"""
request = dataproc.messages.DataprocProjectsRegionsJobsGetRequest(
projectId=job_ref.projectId, region=job_ref.region, jobId=job_ref.jobId)
driver_log_stream = None
last_job_poll_time = 0
job_complete = False
wait_display = None
driver_output_uri = None
def ReadDriverLogIfPresent():
if driver_log_stream and driver_log_stream.open:
# TODO(b/36049794): Don't read all output.
driver_log_stream.ReadIntoWritable(log.err)
def PrintEqualsLine():
attr = console_attr.GetConsoleAttr()
log.err.Print('=' * attr.GetTermSize()[0])
if stream_driver_log:
log.status.Print('Waiting for job output...')
wait_display = NoOpProgressDisplay()
else:
wait_display = progress_tracker.ProgressTracker(message, autotick=True)
start_time = now = time.time()
with wait_display:
while not timeout_s or timeout_s > (now - start_time):
      # Poll the driver log stream first to see whether it has closed.
ReadDriverLogIfPresent()
log_stream_closed = driver_log_stream and not driver_log_stream.open
if (not job_complete and
job.status.state in dataproc.terminal_job_states):
job_complete = True
        # Wait an additional 10s to get trailing output.
timeout_s = now - start_time + 10
if job_complete and (not stream_driver_log or log_stream_closed):
# Nothing left to wait for
break
regular_job_poll = (
not job_complete
# Poll less frequently on dataproc API
and now >= last_job_poll_time + dataproc_poll_period_s)
# Poll at regular frequency before output has streamed and after it has
# finished.
expecting_output_stream = stream_driver_log and not driver_log_stream
expecting_job_done = not job_complete and log_stream_closed
if regular_job_poll or expecting_output_stream or expecting_job_done:
last_job_poll_time = now
try:
job = dataproc.client.projects_regions_jobs.Get(request)
except apitools_exceptions.HttpError as error:
log.warning('GetJob failed:\n{}'.format(six.text_type(error)))
# Do not retry on 4xx errors.
if IsClientHttpException(error):
raise
if (stream_driver_log and job.driverOutputResourceUri and
job.driverOutputResourceUri != driver_output_uri):
if driver_output_uri:
PrintEqualsLine()
log.warning("Job attempt failed. Streaming new attempt's output.")
PrintEqualsLine()
driver_output_uri = job.driverOutputResourceUri
driver_log_stream = storage_helpers.StorageObjectSeriesStream(
job.driverOutputResourceUri)
time.sleep(log_poll_period_s)
now = time.time()
state = job.status.state
# goal_state and error_state will always be terminal
if state in dataproc.terminal_job_states:
if stream_driver_log:
if not driver_log_stream:
log.warning('Expected job output not found.')
elif driver_log_stream.open:
log.warning('Job terminated, but output did not finish streaming.')
if state is goal_state:
return job
if error_state and state is error_state:
if job.status.details:
raise exceptions.JobError('Job [{0}] failed with error:\n{1}'.format(
job_ref.jobId, job.status.details))
raise exceptions.JobError('Job [{0}] failed.'.format(job_ref.jobId))
if job.status.details:
log.info('Details:\n' + job.status.details)
raise exceptions.JobError(
'Job [{0}] entered state [{1}] while waiting for [{2}].'.format(
job_ref.jobId, state, goal_state))
raise exceptions.JobTimeoutError(
'Job [{0}] timed out while in state [{1}].'.format(job_ref.jobId, state))
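# Sketch of how WaitForJobTermination is typically invoked (argument values,
# including the JobStatus enum members shown, are assumptions for
# illustration):
#
#   job = WaitForJobTermination(
#       dataproc, job, job_ref,
#       message='Waiting for job completion',
#       goal_state=dataproc.messages.JobStatus.StateValueValuesEnum.DONE,
#       error_state=dataproc.messages.JobStatus.StateValueValuesEnum.ERROR,
#       stream_driver_log=True)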
# This replicates the fallthrough logic of flags._RegionAttributeConfig.
# It is necessary in cases like the --region flag where we are not parsing
# ResourceSpecs.
def ResolveRegion():
return properties.VALUES.dataproc.region.GetOrFail()
# This replicates the fallthrough logic of flags._LocationAttributeConfig.
# It is necessary in cases like the --location flag where we are not parsing
# ResourceSpecs.
def ResolveLocation():
return properties.VALUES.dataproc.location.GetOrFail()
# You probably want to use flags.AddClusterResourceArgument instead.
# If calling this method, you *must* have called flags.AddRegionFlag first to
# ensure a --region flag is stored into properties, which ResolveRegion
# depends on. This is also mutually incompatible with any usage of args.CONCEPTS
# which use --region as a resource attribute.
def ParseCluster(name, dataproc):
ref = dataproc.resources.Parse(
name,
params={
'region': ResolveRegion,
'projectId': properties.VALUES.core.project.GetOrFail
},
collection='dataproc.projects.regions.clusters')
return ref
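# Illustrative resolution (placeholder values): with the region property set,
# ParseCluster('my-cluster', dataproc).RelativeName() yields
# 'projects/<project-id>/regions/<region>/clusters/my-cluster'.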
# You probably want to use flags.AddJobResourceArgument instead.
# If calling this method, you *must* have called flags.AddRegionFlag first to
# ensure a --region flag is stored into properties, which ResolveRegion
# depends on. This is also mutually incompatible with any usage of args.CONCEPTS
# which use --region as a resource attribute.
def ParseJob(job_id, dataproc):
ref = dataproc.resources.Parse(
job_id,
params={
'region': ResolveRegion,
'projectId': properties.VALUES.core.project.GetOrFail
},
collection='dataproc.projects.regions.jobs')
return ref
def ParseOperationJsonMetadata(metadata_value, metadata_type):
"""Returns an Operation message for a metadata value."""
if not metadata_value:
return metadata_type()
return encoding.JsonToMessage(metadata_type,
encoding.MessageToJson(metadata_value))
# Used in bizarre scenarios where we want a qualified region rather than a
# short name
def ParseRegion(dataproc):
ref = dataproc.resources.Parse(
None,
params={
'regionId': ResolveRegion,
'projectId': properties.VALUES.core.project.GetOrFail
},
collection='dataproc.projects.regions')
return ref
# Get dataproc.projects.locations resource
def ParseProjectsLocations(dataproc):
ref = dataproc.resources.Parse(
None,
params={
'locationsId': ResolveRegion,
'projectsId': properties.VALUES.core.project.GetOrFail
},
collection='dataproc.projects.locations')
return ref
# Get dataproc.projects.locations resource
# This can be merged with ParseProjectsLocations() once we have migrated batches
# from `region` to `location`.
def ParseProjectsLocationsForSession(dataproc):
ref = dataproc.resources.Parse(
None,
params={
'locationsId': ResolveLocation(),
'projectsId': properties.VALUES.core.project.GetOrFail
},
collection='dataproc.projects.locations')
return ref
def ReadAutoscalingPolicy(dataproc, policy_id, policy_file_name=None):
"""Returns autoscaling policy read from YAML file.
Args:
dataproc: wrapper for dataproc resources, client and messages.
policy_id: The autoscaling policy id (last piece of the resource name).
policy_file_name: if set, location of the YAML file to read from. Otherwise,
reads from stdin.
Raises:
argparse.ArgumentError if duration formats are invalid or out of bounds.
"""
data = console_io.ReadFromFileOrStdin(policy_file_name or '-', binary=False)
policy = export_util.Import(
message_type=dataproc.messages.AutoscalingPolicy, stream=data)
  # Ignore the user-set id in the file (if any) and overwrite it with the
  # policy id provided to this command.
policy.id = policy_id
# Similarly, ignore the set resource name. This field is OUTPUT_ONLY, so we
# can just clear it.
policy.name = None
# Set duration fields to their seconds values
if policy.basicAlgorithm is not None:
if policy.basicAlgorithm.cooldownPeriod is not None:
policy.basicAlgorithm.cooldownPeriod = str(
arg_parsers.Duration(lower_bound='2m', upper_bound='1d')(
policy.basicAlgorithm.cooldownPeriod)) + 's'
if policy.basicAlgorithm.yarnConfig.gracefulDecommissionTimeout is not None:
policy.basicAlgorithm.yarnConfig.gracefulDecommissionTimeout = str(
arg_parsers.Duration(lower_bound='0s', upper_bound='1d')
(policy.basicAlgorithm.yarnConfig.gracefulDecommissionTimeout)) + 's'
return policy
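# Worked example of the duration normalization in ReadAutoscalingPolicy
# (illustrative values): a policy YAML containing
#
#   basicAlgorithm:
#     cooldownPeriod: 5m
#
# is read back with policy.basicAlgorithm.cooldownPeriod == '300s', since
# arg_parsers.Duration() parses the duration into whole seconds and the code
# above re-appends the 's' suffix expected by the API.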
def CreateAutoscalingPolicy(dataproc, name, policy):
"""Returns the server-resolved policy after creating the given policy.
Args:
dataproc: wrapper for dataproc resources, client and messages.
name: The autoscaling policy resource name.
policy: The AutoscalingPolicy message to create.
"""
# TODO(b/109837200) make the dataproc discovery doc parameters consistent
# Parent() fails for the collection because of projectId/projectsId and
# regionId/regionsId inconsistencies.
# parent = template_ref.Parent().RelativePath()
parent = '/'.join(name.split('/')[0:4])
request = \
dataproc.messages.DataprocProjectsRegionsAutoscalingPoliciesCreateRequest(
parent=parent,
autoscalingPolicy=policy)
policy = dataproc.client.projects_regions_autoscalingPolicies.Create(request)
log.status.Print('Created [{0}].'.format(policy.id))
return policy
def UpdateAutoscalingPolicy(dataproc, name, policy):
"""Returns the server-resolved policy after updating the given policy.
Args:
dataproc: wrapper for dataproc resources, client and messages.
name: The autoscaling policy resource name.
    policy: The AutoscalingPolicy message to update.
"""
# Though the name field is OUTPUT_ONLY in the API, the Update() method of the
# gcloud generated dataproc client expects it to be set.
policy.name = name
policy = \
dataproc.client.projects_regions_autoscalingPolicies.Update(policy)
log.status.Print('Updated [{0}].'.format(policy.id))
return policy
def _DownscopeCredentials(token, access_boundary_json):
"""Downscope the given credentials to the given access boundary.
Args:
token: The credentials to downscope.
access_boundary_json: The JSON-formatted access boundary.
Returns:
A downscopded credential with the given access-boundary.
"""
payload = {
'grant_type': 'urn:ietf:params:oauth:grant-type:token-exchange',
'requested_token_type': 'urn:ietf:params:oauth:token-type:access_token',
'subject_token_type': 'urn:ietf:params:oauth:token-type:access_token',
'subject_token': token,
'options': access_boundary_json
}
universe_domain = properties.VALUES.core.universe_domain.Get()
cab_token_url = f'https://sts.{universe_domain}/v1/token'
headers = {'Content-Type': 'application/x-www-form-urlencoded'}
downscope_response = requests.GetSession().post(
cab_token_url, headers=headers, data=payload)
if downscope_response.status_code != 200:
raise ValueError('Error downscoping credentials')
cab_token = json.loads(downscope_response.content)
return cab_token.get('access_token', None)
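# For reference, the access_boundary_json passed to _DownscopeCredentials is
# the JSON-serialized Credential Access Boundary sent in the 'options' field
# of the STS exchange. A minimal sketch of its usual shape (the resource and
# role values are illustrative only):
#
#   {
#     "accessBoundary": {
#       "accessBoundaryRules": [{
#         "availableResource":
#             "//storage.googleapis.com/projects/_/buckets/<bucket>",
#         "availablePermissions": ["inRole:roles/storage.objectViewer"]
#       }]
#     }
#   }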
def GetCredentials(access_boundary_json):
"""Get an access token for the user's current credentials.
Args:
access_boundary_json: JSON string holding the definition of the access
boundary to apply to the credentials.
Raises:
PersonalAuthError: If no access token could be fetched for the user.
Returns:
An access token for the user.
"""
cred = c_store.Load(
None, allow_account_impersonation=True, use_google_auth=True)
c_store.Refresh(cred)
if c_creds.IsOauth2ClientCredentials(cred):
token = cred.access_token
else:
token = cred.token
if not token:
raise exceptions.PersonalAuthError(
'No access token could be obtained from the current credentials.')
return _DownscopeCredentials(token, access_boundary_json)
class PersonalAuthUtils(object):
"""Util functions for enabling personal auth session."""
def __init__(self):
pass
def _RunOpensslCommand(self, openssl_executable, args, stdin=None):
"""Run the specified command, capturing and returning output as appropriate.
Args:
openssl_executable: The path to the openssl executable.
args: The arguments to the openssl command to run.
stdin: The input to the command.
Returns:
The output of the command.
Raises:
      PersonalAuthError: If the call to openssl fails.
"""
command = [openssl_executable]
command.extend(args)
stderr = None
try:
if getattr(subprocess, 'run', None):
proc = subprocess.run(
command,
input=stdin,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=False)
stderr = proc.stderr.decode('utf-8').strip()
# N.B. It would be better if we could simply call `subprocess.run` with
# the `check` keyword arg set to true rather than manually calling
# `check_returncode`. However, we want to capture the stderr when the
# command fails, and the CalledProcessError type did not have a field
# for the stderr until Python version 3.5.
#
# As such, we need to manually call `check_returncode` as long as we
# are supporting Python versions prior to 3.5.
proc.check_returncode()
return proc.stdout
else:
p = subprocess.Popen(
command,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stdout, _ = p.communicate(input=stdin)
return stdout
except Exception as ex:
if stderr:
log.error('OpenSSL command "%s" failed with error message "%s"',
' '.join(command), stderr)
raise exceptions.PersonalAuthError('Failure running openssl command: "' +
' '.join(command) + '": ' +
six.text_type(ex))
def _ComputeHmac(self, key, data, openssl_executable):
"""Compute HMAC tag using OpenSSL."""
cmd_output = self._RunOpensslCommand(
openssl_executable, ['dgst', '-sha256', '-hmac', key],
stdin=data).decode('utf-8')
try:
# Split the openssl output to get the HMAC.
stripped_output = cmd_output.strip().split(' ')[1]
if len(stripped_output) != 64:
raise ValueError('HMAC output is expected to be 64 characters long.')
int(stripped_output, 16) # Check that the HMAC is in hex format.
except Exception as ex:
raise exceptions.PersonalAuthError(
'Failure due to invalid openssl output: ' + six.text_type(ex))
return (stripped_output + '\n').encode('utf-8')
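  # _ComputeHmac relies on the usual `openssl dgst` output shape, which looks
  # roughly like (illustrative): '(stdin)= 3f62...<64 hex characters>', so
  # splitting on the space and taking the second field yields the hex tag.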
def _DeriveHkdfKey(self, prk, info, openssl_executable):
"""Derives HMAC-based Key Derivation Function (HKDF) key through expansion on the initial pseudorandom key.
Args:
prk: a pseudorandom key.
info: optional context and application specific information (can be
empty).
openssl_executable: The path to the openssl executable.
Returns:
Output keying material, expected to be of 256-bit length.
"""
if len(prk) != 32:
raise ValueError(
'The given initial pseudorandom key is expected to be 32 bytes long.')
base16_prk = base64.b16encode(prk).decode('utf-8')
t1 = self._ComputeHmac(base16_prk, b'', openssl_executable)
t2data = bytearray(t1)
t2data.extend(info)
t2data.extend(b'\x01')
return self._ComputeHmac(base16_prk, t2data, openssl_executable)
# It is possible (although very rare) for the random pad generated by
# openssl to not be usable by openssl for encrypting the secret. When
# that happens the call to openssl will raise a CalledProcessError with
# the message "Error reading password from BIO\nError getting password".
#
# To account for this we retry on that error, but this is so rare that
# a single retry should be sufficient.
@retry.RetryOnException(max_retrials=1)
def _EncodeTokenUsingOpenssl(self, public_key, secret, openssl_executable):
"""Encode token using OpenSSL.
Args:
public_key: The public key for the session/cluster.
secret: Token to be encrypted.
openssl_executable: The path to the openssl executable.
Returns:
Encrypted token.
"""
key_hash = hashlib.sha256((public_key + '\n').encode('utf-8')).hexdigest()
iv_bytes = base64.b16encode(os.urandom(16))
initialization_vector = iv_bytes.decode('utf-8')
initial_key = os.urandom(32)
encryption_key = self._DeriveHkdfKey(initial_key,
'encryption_key'.encode('utf-8'),
openssl_executable)
auth_key = base64.b16encode(
self._DeriveHkdfKey(initial_key, 'auth_key'.encode('utf-8'),
openssl_executable)).decode('utf-8')
with tempfile.NamedTemporaryFile() as kf:
kf.write(public_key.encode('utf-8'))
kf.seek(0)
encrypted_key = self._RunOpensslCommand(
openssl_executable,
['rsautl', '-oaep', '-encrypt', '-pubin', '-inkey', kf.name],
stdin=base64.b64encode(initial_key))
if len(encrypted_key) != 512:
raise ValueError('The encrypted key is expected to be 512 bytes long.')
encoded_key = base64.b64encode(encrypted_key).decode('utf-8')
with tempfile.NamedTemporaryFile() as pf:
pf.write(encryption_key)
pf.seek(0)
encrypt_args = [
'enc', '-aes-256-ctr', '-salt', '-iv', initialization_vector, '-pass',
'file:{}'.format(pf.name)
]
encrypted_token = self._RunOpensslCommand(
openssl_executable, encrypt_args, stdin=secret.encode('utf-8'))
encoded_token = base64.b64encode(encrypted_token).decode('utf-8')
hmac_input = bytearray(iv_bytes)
hmac_input.extend(encrypted_token)
hmac_tag = self._ComputeHmac(auth_key, hmac_input,
openssl_executable).decode('utf-8')[
0:32] # Truncate the HMAC tag to 128-bit
return '{}:{}:{}:{}:{}'.format(key_hash, encoded_token, encoded_key,
initialization_vector, hmac_tag)
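  # The token produced by _EncodeTokenUsingOpenssl therefore has five
  # colon-separated fields: the SHA-256 hash of the public key, the base64
  # AES-256-CTR ciphertext of the secret, the base64 RSA-OAEP-wrapped key
  # material, the hex-encoded IV, and an HMAC-SHA256 tag over IV plus
  # ciphertext, truncated to 128 bits.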
def EncryptWithPublicKey(self, public_key, secret, openssl_executable):
"""Encrypt secret with resource public key.
Args:
public_key: The public key for the session/cluster.
secret: Token to be encrypted.
openssl_executable: The path to the openssl executable.
Returns:
Encrypted token.
"""
if openssl_executable:
return self._EncodeTokenUsingOpenssl(public_key, secret,
openssl_executable)
try:
# pylint: disable=g-import-not-at-top
import tink
from tink import hybrid
# pylint: enable=g-import-not-at-top
except ImportError:
raise exceptions.PersonalAuthError(
'Cannot load the Tink cryptography library. Either the '
'library is not installed, or site packages are not '
'enabled for the Google Cloud SDK. Please consult Cloud '
'Dataproc Personal Auth documentation on adding Tink to '
'Google Cloud SDK for further instructions.\n'
'https://cloud.google.com/dataproc/docs/concepts/iam/personal-auth')
hybrid.register()
context = b''
# Extract value of key corresponding to primary key.
public_key_value = json.loads(public_key)['key'][0]['keyData']['value']
key_hash = hashlib.sha256(
(public_key_value + '\n').encode('utf-8')).hexdigest()
# Load public key and create keyset handle.
reader = tink.JsonKeysetReader(public_key)
kh_pub = tink.read_no_secret_keyset_handle(reader)
# Create encrypter instance.
encrypter = kh_pub.primitive(hybrid.HybridEncrypt)
ciphertext = encrypter.encrypt(secret.encode('utf-8'), context)
encoded_token = base64.b64encode(ciphertext).decode('utf-8')
return '{}:{}'.format(key_hash, encoded_token)
def IsTinkLibraryInstalled(self):
"""Check if Tink cryptography library can be loaded."""
try:
# pylint: disable=g-import-not-at-top
# pylint: disable=unused-import
import tink
from tink import hybrid
# pylint: enable=g-import-not-at-top
# pylint: enable=unused-import
return True
except ImportError:
return False
def ReadSessionTemplate(dataproc, template_file_name=None):
"""Returns session template read from YAML file.
Args:
dataproc: Wrapper for dataproc resources, client and messages.
template_file_name: If set, location of the YAML file to read from.
Otherwise, reads from stdin.
Raises:
argparse.ArgumentError if duration formats are invalid or out of bounds.
"""
data = console_io.ReadFromFileOrStdin(template_file_name or '-', binary=False)
template = export_util.Import(
message_type=dataproc.messages.SessionTemplate, stream=data)
return template
def CreateSessionTemplate(dataproc, name, template):
"""Returns the server-resolved template after creating the given template.
Args:
dataproc: Wrapper for dataproc resources, client and messages.
name: The session template resource name.
template: The SessionTemplate message to create.
"""
parent = '/'.join(name.split('/')[0:4])
template.name = name
request = (
dataproc.messages.DataprocProjectsLocationsSessionTemplatesCreateRequest(
parent=parent,
sessionTemplate=template))
template = dataproc.client.projects_locations_sessionTemplates.Create(request)
log.status.Print('Created [{0}].'.format(template.name))
return template
def UpdateSessionTemplate(dataproc, name, template):
"""Returns the server-resolved template after updating the given template.
Args:
dataproc: Wrapper for dataproc resources, client and messages.
name: The session template resource name.
    template: The SessionTemplate message to update.
"""
template.name = name
template = dataproc.client.projects_locations_sessionTemplates.Patch(template)
log.status.Print('Updated [{0}].'.format(template.name))
return template
def YieldFromListWithUnreachableList(unreachable_warning_msg, *args, **kwargs):
"""Yields from paged List calls handling unreachable list."""
unreachable = set()
def _GetFieldFn(message, attr):
unreachable.update(message.unreachable)
return getattr(message, attr)
result = list_pager.YieldFromList(get_field_func=_GetFieldFn, *args, **kwargs)
for item in result:
yield item
if unreachable:
log.warning(
unreachable_warning_msg,
', '.join(sorted(unreachable)),
)
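# Hypothetical usage sketch for YieldFromListWithUnreachableList (the service,
# request, and field names below are illustrative, not part of this module):
#
#   for batch in YieldFromListWithUnreachableList(
#       'The following locations were unreachable: %s',
#       dataproc.client.projects_locations_batches,
#       list_request,
#       field='batches',
#       batch_size_attribute='pageSize'):
#     ProcessBatch(batch)  # placeholder for per-item handling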