File: //snap/google-cloud-cli/396/lib/googlecloudsdk/api_lib/datapipelines/util.py
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data Pipelines API utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
import six
_DEFAULT_API_VERSION = 'v1'


def GetMessagesModule(api_version=_DEFAULT_API_VERSION):
  return apis.GetMessagesModule('datapipelines', api_version)


def GetClientInstance(api_version=_DEFAULT_API_VERSION):
  return apis.GetClientInstance('datapipelines', api_version)


def GetPipelineURI(resource):
  pipeline = resources.REGISTRY.ParseRelativeName(
      resource.name, collection='datapipelines.pipelines')
  return pipeline.SelfLink()


def GetJobURI(resource):
  job = resources.REGISTRY.ParseRelativeName(
      resource.name, collection='datapipelines.pipelines.jobs')
  return job.SelfLink()
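
# A minimal usage sketch for the URI helpers above, kept in comments so the
# module's import-time behavior is unchanged; the project, location and
# pipeline names are placeholders, not values this module defines:
#
#   class _FakeResource(object):
#     name = ('projects/my-project/locations/us-central1/'
#             'pipelines/my-pipeline')
#
#   GetPipelineURI(_FakeResource())
#   # -> the pipeline's self link, e.g.
#   # https://datapipelines.googleapis.com/v1/projects/my-project/...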


class PipelinesClient(object):
  """Client for Pipelines for the Data Pipelines API."""

  def __init__(self, client=None, messages=None):
    self.client = client or GetClientInstance()
    self.messages = messages or GetMessagesModule()
    self._service = self.client.projects_locations_pipelines

  def Describe(self, pipeline):
    """Describe a Pipeline in the given project and region.

    Args:
      pipeline: str, the name for the Pipeline being described.

    Returns:
      Described Pipeline Resource.
    """
    describe_req = self.messages.DatapipelinesProjectsLocationsPipelinesGetRequest(
        name=pipeline)
    return self._service.Get(describe_req)
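
  # A minimal call sketch for Describe (the same pattern applies to Delete,
  # Stop and Run below); the resource name is a placeholder:
  #
  #   client = PipelinesClient()
  #   client.Describe(
  #       'projects/my-project/locations/us-central1/pipelines/my-pipeline')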

  def Delete(self, pipeline):
    """Delete a Pipeline in the given project and region.

    Args:
      pipeline: str, the name for the Pipeline being deleted.

    Returns:
      Empty Response.
    """
    delete_req = self.messages.DatapipelinesProjectsLocationsPipelinesDeleteRequest(
        name=pipeline)
    return self._service.Delete(delete_req)

  def Stop(self, pipeline):
    """Stop a Pipeline in the given project and region.

    Args:
      pipeline: str, the name for the Pipeline being stopped.

    Returns:
      Pipeline resource.
    """
    stop_req = self.messages.DatapipelinesProjectsLocationsPipelinesStopRequest(
        name=pipeline)
    return self._service.Stop(stop_req)

  def Run(self, pipeline):
    """Run a Pipeline in the given project and region.

    Args:
      pipeline: str, the name for the Pipeline being run.

    Returns:
      Job resource which was created.
    """
    run_req = self.messages.DatapipelinesProjectsLocationsPipelinesRunRequest(
        name=pipeline)
    return self._service.Run(run_req)

  def List(self, limit=None, page_size=50, input_filter='', region=''):
    """List Pipelines for the given project and region.

    Args:
      limit: int or None, the total number of results to return.
      page_size: int, the number of entries in each batch (affects requests
        made, but not the yielded results).
      input_filter: string, optional filter to pass, e.g.
        "type:BATCH,status:ALL", to filter the pipelines by status or type.
      region: string, relative name of the region.

    Returns:
      Generator of matching pipelines.
    """
    list_req = self.messages.DatapipelinesProjectsLocationsPipelinesListRequest(
        filter=input_filter, parent=region)
    return list_pager.YieldFromList(
        self._service,
        list_req,
        field='pipelines',
        method='List',
        batch_size=page_size,
        limit=limit,
        batch_size_attribute='pageSize')
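
  # Sketch of consuming the generator returned by List; the region is a
  # placeholder relative resource name:
  #
  #   client = PipelinesClient()
  #   for pipeline in client.List(
  #       page_size=10,
  #       input_filter='type:BATCH,status:ALL',
  #       region='projects/my-project/locations/us-central1'):
  #     print(pipeline.name)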

  def CreateLegacyTemplateRequest(self, args):
    """Create a Legacy Template request for the Pipeline workload.

    Args:
      args: Any, object with args needed to create a Pipeline.

    Returns:
      Legacy Template request.
    """
    location = args.region
    project_id = properties.VALUES.core.project.Get(required=True)
    params_list = self.ConvertDictArguments(
        args.parameters, self.messages
        .GoogleCloudDatapipelinesV1LaunchTemplateParameters.ParametersValue)
    transform_mapping_list = self.ConvertDictArguments(
        args.transform_name_mappings,
        self.messages.GoogleCloudDatapipelinesV1LaunchTemplateParameters
        .TransformNameMappingValue)
    transform_name_mappings = None
    if transform_mapping_list:
      transform_name_mappings = self.messages.GoogleCloudDatapipelinesV1LaunchTemplateParameters.TransformNameMappingValue(
          additionalProperties=transform_mapping_list)
    ip_private = self.messages.GoogleCloudDatapipelinesV1RuntimeEnvironment.IpConfigurationValueValuesEnum.WORKER_IP_PRIVATE
    ip_configuration = ip_private if args.disable_public_ips else None
    user_labels_list = self.ConvertDictArguments(
        args.additional_user_labels, self.messages
        .GoogleCloudDatapipelinesV1RuntimeEnvironment.AdditionalUserLabelsValue)
    additional_user_labels = None
    if user_labels_list:
      additional_user_labels = self.messages.GoogleCloudDatapipelinesV1RuntimeEnvironment.AdditionalUserLabelsValue(
          additionalProperties=user_labels_list)
    launch_parameter = self.messages.GoogleCloudDatapipelinesV1LaunchTemplateParameters(
        environment=self.messages.GoogleCloudDatapipelinesV1RuntimeEnvironment(
            serviceAccountEmail=args.dataflow_service_account_email,
            maxWorkers=args.max_workers,
            numWorkers=args.num_workers,
            network=args.network,
            subnetwork=args.subnetwork,
            machineType=args.worker_machine_type,
            tempLocation=args.temp_location,
            kmsKeyName=args.dataflow_kms_key,
            ipConfiguration=ip_configuration,
            workerRegion=args.worker_region,
            workerZone=args.worker_zone,
            enableStreamingEngine=args.enable_streaming_engine,
            additionalExperiments=(args.additional_experiments
                                   if args.additional_experiments else []),
            additionalUserLabels=additional_user_labels),
        update=args.update,
        parameters=self.messages
        .GoogleCloudDatapipelinesV1LaunchTemplateParameters.ParametersValue(
            additionalProperties=params_list) if params_list else None,
        transformNameMapping=transform_name_mappings)
    return self.messages.GoogleCloudDatapipelinesV1LaunchTemplateRequest(
        gcsPath=args.template_file_gcs_location,
        location=location,
        projectId=project_id,
        launchParameters=launch_parameter)

  def CreateFlexTemplateRequest(self, args):
    """Create a Flex Template request for the Pipeline workload.

    Args:
      args: Any, object with args needed to create a Pipeline.

    Returns:
      Flex Template request.
    """
    location = args.region
    project_id = properties.VALUES.core.project.Get(required=True)
    params_list = self.ConvertDictArguments(
        args.parameters, self.messages
        .GoogleCloudDatapipelinesV1LaunchFlexTemplateParameter.ParametersValue)
    transform_mapping_list = self.ConvertDictArguments(
        args.transform_name_mappings,
        self.messages.GoogleCloudDatapipelinesV1LaunchFlexTemplateParameter
        .TransformNameMappingsValue)
    transform_name_mappings = None
    if transform_mapping_list:
      transform_name_mappings = self.messages.GoogleCloudDatapipelinesV1LaunchFlexTemplateParameter.TransformNameMappingsValue(
          additionalProperties=transform_mapping_list)
    ip_private = self.messages.GoogleCloudDatapipelinesV1FlexTemplateRuntimeEnvironment.IpConfigurationValueValuesEnum.WORKER_IP_PRIVATE
    ip_configuration = ip_private if args.disable_public_ips else None
    user_labels_list = self.ConvertDictArguments(
        args.additional_user_labels,
        self.messages.GoogleCloudDatapipelinesV1FlexTemplateRuntimeEnvironment
        .AdditionalUserLabelsValue)
    additional_user_labels = None
    if user_labels_list:
      additional_user_labels = self.messages.GoogleCloudDatapipelinesV1FlexTemplateRuntimeEnvironment.AdditionalUserLabelsValue(
          additionalProperties=user_labels_list)
    flexrs_goal = None
    if args.flexrs_goal:
      if args.flexrs_goal == 'SPEED_OPTIMIZED':
        flexrs_goal = self.messages.GoogleCloudDatapipelinesV1FlexTemplateRuntimeEnvironment.FlexrsGoalValueValuesEnum.FLEXRS_SPEED_OPTIMIZED
      elif args.flexrs_goal == 'COST_OPTIMIZED':
        flexrs_goal = self.messages.GoogleCloudDatapipelinesV1FlexTemplateRuntimeEnvironment.FlexrsGoalValueValuesEnum.FLEXRS_COST_OPTIMIZED
    launch_parameter = self.messages.GoogleCloudDatapipelinesV1LaunchFlexTemplateParameter(
        containerSpecGcsPath=args.template_file_gcs_location,
        environment=self.messages
        .GoogleCloudDatapipelinesV1FlexTemplateRuntimeEnvironment(
            serviceAccountEmail=args.dataflow_service_account_email,
            maxWorkers=args.max_workers,
            numWorkers=args.num_workers,
            network=args.network,
            subnetwork=args.subnetwork,
            machineType=args.worker_machine_type,
            tempLocation=args.temp_location,
            kmsKeyName=args.dataflow_kms_key,
            ipConfiguration=ip_configuration,
            workerRegion=args.worker_region,
            workerZone=args.worker_zone,
            enableStreamingEngine=args.enable_streaming_engine,
            flexrsGoal=flexrs_goal,
            additionalExperiments=(args.additional_experiments
                                   if args.additional_experiments else []),
            additionalUserLabels=additional_user_labels),
        update=args.update,
        parameters=self.messages
        .GoogleCloudDatapipelinesV1LaunchFlexTemplateParameter.ParametersValue(
            additionalProperties=params_list) if params_list else None,
        transformNameMappings=transform_name_mappings)
    return self.messages.GoogleCloudDatapipelinesV1LaunchFlexTemplateRequest(
        location=location,
        projectId=project_id,
        launchParameter=launch_parameter)

  def Create(self, pipeline, parent, args):
    """Create a Pipeline in the given project and region.

    Args:
      pipeline: str, the name for the Pipeline being created.
      parent: str, relative name of the region.
      args: Any, object with args needed to create a Pipeline.

    Returns:
      Pipeline resource.
    """
    if args.pipeline_type == 'streaming':
      pipeline_type = (
          self.messages.GoogleCloudDatapipelinesV1Pipeline.TypeValueValuesEnum
          .PIPELINE_TYPE_STREAMING)
    else:
      pipeline_type = (
          self.messages.GoogleCloudDatapipelinesV1Pipeline.TypeValueValuesEnum
          .PIPELINE_TYPE_BATCH)
    schedule_info = self.messages.GoogleCloudDatapipelinesV1ScheduleSpec(
        schedule=args.schedule, timeZone=args.time_zone)
    if args.template_type == 'classic':
      legacy_template_request = self.CreateLegacyTemplateRequest(args)
      workload = self.messages.GoogleCloudDatapipelinesV1Workload(
          dataflowLaunchTemplateRequest=legacy_template_request)
    else:
      flex_template_request = self.CreateFlexTemplateRequest(args)
      workload = self.messages.GoogleCloudDatapipelinesV1Workload(
          dataflowFlexTemplateRequest=flex_template_request)
    if args.display_name:
      display_name = args.display_name
    else:
      display_name = pipeline.rsplit('/', 1)[-1]
    pipeline_spec = self.messages.GoogleCloudDatapipelinesV1Pipeline(
        name=pipeline,
        displayName=display_name,
        type=pipeline_type,
        scheduleInfo=schedule_info,
        schedulerServiceAccountEmail=args.scheduler_service_account_email,
        workload=workload)
    create_req = self.messages.DatapipelinesProjectsLocationsPipelinesCreateRequest(
        googleCloudDatapipelinesV1Pipeline=pipeline_spec, parent=parent)
    return self._service.Create(create_req)
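
  # Sketch of a Create call; `args` is assumed to be a parsed argparse
  # namespace carrying the flags referenced above, and the resource names
  # are placeholders:
  #
  #   parent = 'projects/my-project/locations/us-central1'
  #   client = PipelinesClient()
  #   client.Create(parent + '/pipelines/my-pipeline', parent, args)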

  def WorkloadUpdateMask(self, template_type, args):
    """Given a set of args for the workload, create the required update mask.

    Args:
      template_type: str, the type of the template ('classic' or 'flex').
      args: Any, object with args needed for updating a pipeline.

    Returns:
      Update mask.
    """
    update_mask = []
    if template_type == 'flex':
      prefix_string = 'workload.dataflow_flex_template_request.launch_parameter.'
    else:
      prefix_string = 'workload.dataflow_launch_template_request.launch_parameters.'
    if args.template_file_gcs_location:
      if template_type == 'flex':
        update_mask.append(prefix_string + 'container_spec_gcs_path')
      else:
        update_mask.append('workload.dataflow_launch_template_request.gcs_path')
    if args.parameters:
      update_mask.append(prefix_string + 'parameters')
    if args.update:
      update_mask.append(prefix_string + 'update')
    if args.transform_name_mappings:
      if template_type == 'flex':
        update_mask.append(prefix_string + 'transform_name_mappings')
      else:
        update_mask.append(prefix_string + 'transform_name_mapping')
    if args.max_workers:
      update_mask.append(prefix_string + 'environment.max_workers')
    if args.num_workers:
      update_mask.append(prefix_string + 'environment.num_workers')
    if args.dataflow_service_account_email:
      update_mask.append(prefix_string + 'environment.service_account_email')
    if args.temp_location:
      update_mask.append(prefix_string + 'environment.temp_location')
    if args.network:
      update_mask.append(prefix_string + 'environment.network')
    if args.subnetwork:
      update_mask.append(prefix_string + 'environment.subnetwork')
    if args.worker_machine_type:
      update_mask.append(prefix_string + 'environment.machine_type')
    if args.dataflow_kms_key:
      update_mask.append(prefix_string + 'environment.kms_key_name')
    if args.disable_public_ips:
      update_mask.append(prefix_string + 'environment.ip_configuration')
    if args.worker_region:
      update_mask.append(prefix_string + 'environment.worker_region')
    if args.worker_zone:
      update_mask.append(prefix_string + 'environment.worker_zone')
    if args.enable_streaming_engine:
      update_mask.append(prefix_string + 'environment.enable_streaming_engine')
    if args.flexrs_goal:
      if template_type == 'flex':
        update_mask.append(prefix_string + 'environment.flexrs_goal')
    if args.additional_user_labels:
      update_mask.append(prefix_string + 'environment.additional_user_labels')
    if args.additional_experiments:
      update_mask.append(prefix_string + 'environment.additional_experiments')
    return update_mask
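
  # Illustrative result (assuming a flex template where only --parameters and
  # --temp-location were supplied), WorkloadUpdateMask('flex', args) yields:
  #
  #   ['workload.dataflow_flex_template_request.launch_parameter.parameters',
  #    'workload.dataflow_flex_template_request.launch_parameter.environment'
  #    '.temp_location']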

  def Patch(self, pipeline, args):
    """Update a Pipeline in the given project and region.

    Args:
      pipeline: str, the name for the Pipeline being updated.
      args: Any, object with args needed to update a Pipeline.

    Returns:
      Pipeline resource.
    """
    update_mask = []
    schedule_info = None
    if args.schedule or args.time_zone:
      schedule, time_zone = None, None
      if args.schedule:
        schedule = args.schedule
        update_mask.append('schedule_info.schedule')
      if args.time_zone:
        time_zone = args.time_zone
        update_mask.append('schedule_info.time_zone')
      schedule_info = self.messages.GoogleCloudDatapipelinesV1ScheduleSpec(
          schedule=schedule, timeZone=time_zone)
    if args.display_name:
      update_mask.append('display_name')
    if args.template_type == 'classic':
      update_mask += self.WorkloadUpdateMask('classic', args)
      legacy_template_request = self.CreateLegacyTemplateRequest(args)
      workload = self.messages.GoogleCloudDatapipelinesV1Workload(
          dataflowLaunchTemplateRequest=legacy_template_request)
    else:
      update_mask += self.WorkloadUpdateMask('flex', args)
      flex_template_request = self.CreateFlexTemplateRequest(args)
      workload = self.messages.GoogleCloudDatapipelinesV1Workload(
          dataflowFlexTemplateRequest=flex_template_request)
    pipeline_spec = self.messages.GoogleCloudDatapipelinesV1Pipeline(
        name=pipeline,
        displayName=args.display_name,
        scheduleInfo=schedule_info,
        schedulerServiceAccountEmail=args.scheduler_service_account_email,
        workload=workload)
    update_req = self.messages.DatapipelinesProjectsLocationsPipelinesPatchRequest(
        googleCloudDatapipelinesV1Pipeline=pipeline_spec,
        name=pipeline,
        updateMask=','.join(update_mask))
    return self._service.Patch(update_req)

  def ConvertDictArguments(self, arguments, value_message):
    """Convert dictionary arguments to a parameter list.

    Args:
      arguments: dict, the key/value arguments to convert.
      value_message: the value message type holding the key/value pairs.

    Returns:
      List of value_message.AdditionalProperty.
    """
    params_list = []
    if arguments:
      for k, v in six.iteritems(arguments):
        params_list.append(value_message.AdditionalProperty(key=k, value=v))
    return params_list
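
  # Sketch: with arguments={'inputFile': 'gs://my-bucket/input.txt'} and a
  # ParametersValue message type (placeholder values), this returns a list
  # with one ParametersValue.AdditionalProperty(key='inputFile',
  # value='gs://my-bucket/input.txt'), ready to wrap in
  # ParametersValue(additionalProperties=...).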


class JobsClient(object):
  """Client for the job-related services of the Data Pipelines API."""

  def __init__(self, client=None, messages=None):
    self.client = client or GetClientInstance()
    self.messages = messages or GetMessagesModule()
    self._service = self.client.projects_locations_pipelines_jobs

  def List(self, limit=None, page_size=50, pipeline=''):
    """Make API calls to list jobs for pipelines.

    Args:
      limit: int or None, the total number of results to return.
      page_size: int, the number of entries in each batch (affects requests
        made, but not the yielded results).
      pipeline: string, the name of the pipeline to list jobs for.

    Returns:
      Generator that yields jobs.
    """
    list_req = self.messages.DatapipelinesProjectsLocationsPipelinesJobsListRequest(
        parent=pipeline)
    return list_pager.YieldFromList(
        self._service,
        list_req,
        field='jobs',
        method='List',
        batch_size=page_size,
        limit=limit,
        batch_size_attribute='pageSize')
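
  # Sketch of listing the jobs of one pipeline; the pipeline name is a
  # placeholder relative resource name:
  #
  #   jobs_client = JobsClient()
  #   for job in jobs_client.List(
  #       pipeline='projects/my-project/locations/us-central1/'
  #                'pipelines/my-pipeline'):
  #     print(job.name)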