File: //snap/google-cloud-cli/current/lib/surface/dataflow/yaml/run.py
# -*- coding: utf-8 -*- #
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of gcloud dataflow yaml run command."""
from googlecloudsdk.api_lib.dataflow import apis
from googlecloudsdk.api_lib.storage import storage_api
from googlecloudsdk.api_lib.storage import storage_util
from googlecloudsdk.calliope import arg_parsers
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.dataflow import dataflow_util
from googlecloudsdk.core import properties
from googlecloudsdk.core import yaml
from googlecloudsdk.core.util import files


@base.DefaultUniverseOnly
@base.ReleaseTracks(base.ReleaseTrack.GA, base.ReleaseTrack.BETA)
class Run(base.Command):
  """Runs a job from the specified path."""

  detailed_help = {
      'DESCRIPTION': (
          'Runs a job from the specified YAML description or '
          'Cloud Storage path.'
      ),
      'EXAMPLES': """\
          To run a job from YAML, run:

            $ {command} my-job --yaml-pipeline-file=gs://yaml-path --region=europe-west1
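
          To run a job with additional pipeline options, run:

            $ {command} my-job --yaml-pipeline-file=gs://yaml-path --pipeline-options=tempLocation=gs://temp-path --region=europe-west1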
""",
}
@staticmethod
def Args(parser):
"""Register flags for this command.
Args:
parser: argparse.ArgumentParser to register arguments with.
"""
    parser.add_argument(
        'job_name',
        metavar='JOB_NAME',
        help='Unique name to assign to the job.',
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        '--yaml-pipeline-file',
        help=(
            'Path of a file defining the YAML pipeline to run. '
            "(Must be a local file or a URL beginning with 'gs://'.)"
        ),
    )
    group.add_argument(
        '--yaml-pipeline',
        help='Inline definition of the YAML pipeline to run.',
    )
    parser.add_argument(
        '--region',
        metavar='REGION_ID',
        help=(
            "Region ID of the job's regional endpoint. "
            + dataflow_util.DEFAULT_REGION_MESSAGE
        ),
    )
    parser.add_argument(
        '--pipeline-options',
        metavar='OPTION=VALUE;OPTION=VALUE',
        type=arg_parsers.ArgDict(),
        action=arg_parsers.UpdateAction,
        help='Pipeline options to pass to the job.',
    )
    parser.add_argument(
        '--jinja-variables',
        metavar='JSON_OBJECT',
        help=(
            'Jinja2 variables, as a JSON object, to be substituted into the '
            'YAML pipeline before it is run.'
        ),
    )
    parser.add_argument(
        '--template-file-gcs-location',
        help=(
            'Google Cloud Storage location of the YAML template to run. '
            "(Must be a URL beginning with 'gs://'.)"
        ),
        type=arg_parsers.RegexpValidator(
            r'^gs://.*', "Must begin with 'gs://'"
        ),
    )
    parser.add_argument(
        '--network',
        help=(
            'Compute Engine network for launching worker instances to run '
            'the pipeline. If not set, the default network is used.'
        ),
    )
    parser.add_argument(
        '--subnetwork',
        help=(
            'Compute Engine subnetwork for launching worker instances to '
            'run the pipeline. If not set, the default subnetwork is used.'
        ),
    )

  def Run(self, args):
    """Runs the command.

    Args:
      args: The arguments that were provided to this command invocation.

    Returns:
      A Job message.
    """
    parameters = dict(args.pipeline_options or {})
    # These are required and mutually exclusive due to the grouping above.
    if args.yaml_pipeline_file:
      yaml_contents = _try_get_yaml_contents(args.yaml_pipeline_file)
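      # A None result means the file could not be read at submit time; forward
      # the path itself as a template parameter instead, at the cost of
      # client-side validation and region detection.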
      if yaml_contents is None:
        parameters['yaml_pipeline_file'] = args.yaml_pipeline_file
      else:
        parameters['yaml_pipeline'] = yaml_contents
    else:
      parameters['yaml_pipeline'] = args.yaml_pipeline
    if args.jinja_variables:
      parameters['jinja_variables'] = args.jinja_variables
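    # Only validate fully-rendered pipelines: with Jinja variables present,
    # the raw YAML may still contain placeholders that do not parse.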
    if 'yaml_pipeline' in parameters and 'jinja_variables' not in parameters:
      _validate_yaml(parameters['yaml_pipeline'])
    region_id = _get_region_from_yaml_or_default(
        parameters.get('yaml_pipeline'), args
    )
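    # Use the user-supplied template location if given; otherwise fall back to
    # the Google-hosted YAML template for the resolved region.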
    gcs_location = (
        args.template_file_gcs_location
        or apis.Templates.YAML_TEMPLATE_GCS_LOCATION.format(region_id)
    )
    arguments = apis.TemplateArguments(
        project_id=properties.VALUES.core.project.Get(required=True),
        region_id=region_id,
        job_name=args.job_name,
        gcs_location=gcs_location,
        parameters=parameters,
        network=args.network,
        subnetwork=args.subnetwork,
    )
    return apis.Templates.CreateJobFromFlexTemplate(arguments)


def _validate_yaml(yaml_pipeline):
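  """Raises a ValueError if the given pipeline does not parse as YAML."""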
  # TODO(b/320740846): Do more complete validation without requiring importing
  # the entire beam library.
  try:
    _ = yaml.load(yaml_pipeline)
  except Exception as exn:  # pylint: disable=broad-exception-caught
    raise ValueError('yaml_pipeline must be valid YAML.') from exn


def _get_region_from_yaml_or_default(yaml_pipeline, args):
"""Gets the region from yaml pipeline or args, or falls back to default."""
region = args.region
options_region = None
try:
pipeline_data = yaml.load(yaml_pipeline)
if not pipeline_data:
return dataflow_util.GetRegion(args)
if 'options' in pipeline_data and 'region' in pipeline_data['options']:
options_region = pipeline_data['options']['region']
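      # Curly braces suggest an unrendered Jinja placeholder rather than a
      # literal region name, so treat the value as unparsable.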
      if '{' in options_region or '}' in options_region:
        raise yaml.YAMLParseError(
            'yaml pipeline contains unparsable region: {0}. Found curly '
            'braces in region. Falling back to default region.'.format(
                options_region
            )
        )
  except yaml.YAMLParseError as exn:
    if not region:
      print(
          'Failed to get region from yaml pipeline: {0}. If using jinja '
          'variables, parsing may fail. Falling back to default '
          'region.'.format(exn)
      )
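  # Prefer the region from the pipeline options, but refuse to proceed if it
  # conflicts with an explicit --region flag.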
  if options_region:
    if region and region != options_region:
      raise ValueError(
          'Region specified in yaml pipeline options ({0}) does not match '
          'region specified in command line ({1})'.format(
              options_region, region
          )
      )
    return options_region
  return dataflow_util.GetRegion(args)


def _try_get_yaml_contents(yaml_pipeline_file):
"""Reads yaml contents from the specified file if it is accessable."""
if not yaml_pipeline_file.startswith('gs://'):
return files.ReadFileContents(yaml_pipeline_file)
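  # A gs:// path: fetch the object contents via the Cloud Storage API.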
  storage_client = storage_api.StorageClient()
  obj_ref = storage_util.ObjectReference.FromUrl(yaml_pipeline_file)
  try:
    return storage_client.ReadObject(obj_ref).read().decode('utf-8')
  except Exception as e:  # pylint: disable=broad-exception-caught
    print(
        'Unable to read file {0} due to incorrect file path or insufficient'
        ' read permissions. Will not be able to validate the yaml pipeline or'
        ' determine the region from the yaml pipeline'
        ' options. Error: {1}'.format(yaml_pipeline_file, e)
    )
    return None