HEX
Server: Apache/2.4.65 (Ubuntu)
System: Linux ielts-store-v2 6.8.0-1036-gcp #38~22.04.1-Ubuntu SMP Thu Aug 14 01:19:18 UTC 2025 x86_64
User: root (0)
PHP: 7.2.34-54+ubuntu20.04.1+deb.sury.org+1
Disabled: pcntl_alarm,pcntl_fork,pcntl_waitpid,pcntl_wait,pcntl_wifexited,pcntl_wifstopped,pcntl_wifsignaled,pcntl_wifcontinued,pcntl_wexitstatus,pcntl_wtermsig,pcntl_wstopsig,pcntl_signal,pcntl_signal_get_handler,pcntl_signal_dispatch,pcntl_get_last_error,pcntl_strerror,pcntl_sigprocmask,pcntl_sigwaitinfo,pcntl_sigtimedwait,pcntl_exec,pcntl_getpriority,pcntl_setpriority,pcntl_async_signals,
Upload Files
File: //snap/google-cloud-cli/394/lib/googlecloudsdk/api_lib/ml_engine/jobs.py
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with ML jobs API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

from apitools.base.py import encoding
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.util.apis import arg_utils
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core import yaml


class NoFieldsSpecifiedError(exceptions.Error):
  """Error indicating that no updates were requested in a Patch operation."""


class NoPackagesSpecifiedError(exceptions.Error):
  """Error that no packages were specified for non-custom training."""


def GetMessagesModule(version='v1'):
  return apis.GetMessagesModule('ml', version)


def GetClientInstance(version='v1', no_http=False):
  return apis.GetClientInstance('ml', version, no_http=no_http)


class JobsClient(object):
  """Client for jobs service in the Cloud ML Engine API."""

  def __init__(self, client=None, messages=None,
               short_message_prefix='GoogleCloudMlV1', client_version='v1'):
    self.client = client or GetClientInstance(client_version)
    self.messages = messages or self.client.MESSAGES_MODULE
    self._short_message_prefix = short_message_prefix

  def GetShortMessage(self, short_message_name):
    return getattr(self.messages,
                   '{prefix}{name}'.format(prefix=self._short_message_prefix,
                                           name=short_message_name), None)

  @property
  def state_enum(self):
    return self.messages.GoogleCloudMlV1Job.StateValueValuesEnum

  def List(self, project_ref):
    req = self.messages.MlProjectsJobsListRequest(
        parent=project_ref.RelativeName())
    return list_pager.YieldFromList(
        self.client.projects_jobs, req, field='jobs',
        batch_size_attribute='pageSize')

  @property
  def job_class(self):
    return self.messages.GoogleCloudMlV1Job

  @property
  def training_input_class(self):
    return self.messages.GoogleCloudMlV1TrainingInput

  @property
  def prediction_input_class(self):
    return self.messages.GoogleCloudMlV1PredictionInput

  def _MakeCreateRequest(self, parent=None, job=None):
    return self.messages.MlProjectsJobsCreateRequest(
        parent=parent,
        googleCloudMlV1Job=job)

  def Create(self, project_ref, job):
    return self.client.projects_jobs.Create(
        self._MakeCreateRequest(
            parent=project_ref.RelativeName(),
            job=job))

  def Cancel(self, job_ref):
    """Cancels given job."""
    req = self.messages.MlProjectsJobsCancelRequest(name=job_ref.RelativeName())
    return self.client.projects_jobs.Cancel(req)

  def Get(self, job_ref):
    req = self.messages.MlProjectsJobsGetRequest(name=job_ref.RelativeName())
    return self.client.projects_jobs.Get(req)

  def Patch(self, job_ref, labels_update):
    """Update a job."""
    job = self.job_class()
    update_mask = []
    if labels_update.needs_update:
      job.labels = labels_update.labels
      update_mask.append('labels')
    if not update_mask:
      raise NoFieldsSpecifiedError('No updates requested.')
    req = self.messages.MlProjectsJobsPatchRequest(
        name=job_ref.RelativeName(),
        googleCloudMlV1Job=job,
        updateMask=','.join(update_mask)
    )
    return self.client.projects_jobs.Patch(req)

  def BuildTrainingJob(self,
                       path=None,
                       module_name=None,
                       job_name=None,
                       trainer_uri=None,
                       region=None,
                       job_dir=None,
                       scale_tier=None,
                       user_args=None,
                       runtime_version=None,
                       python_version=None,
                       network=None,
                       service_account=None,
                       labels=None,
                       kms_key=None,
                       custom_train_server_config=None,
                       enable_web_access=None):
    """Builds a Cloud ML Engine Job from a config file and/or flag values.

    Args:
        path: path to a yaml configuration file
        module_name: value to set for moduleName field (overrides yaml file)
        job_name: value to set for jobName field (overrides yaml file)
        trainer_uri: List of values to set for trainerUri field (overrides yaml
          file)
        region: compute region in which to run the job (overrides yaml file)
        job_dir: Cloud Storage working directory for the job (overrides yaml
          file)
        scale_tier: ScaleTierValueValuesEnum the scale tier for the job
          (overrides yaml file)
        user_args: [str]. A list of arguments to pass through to the job.
        (overrides yaml file)
        runtime_version: the runtime version in which to run the job (overrides
          yaml file)
        python_version: the Python version in which to run the job (overrides
          yaml file)
        network: user network to which the job should be peered with (overrides
          yaml file)
        service_account: A service account (email address string) to use for the
          job.
        labels: Job.LabelsValue, the Cloud labels for the job
        kms_key: A customer-managed encryption key to use for the job.
        custom_train_server_config: jobs_util.CustomTrainingInputServerConfig,
          configuration object for custom server parameters.
        enable_web_access: whether to enable the interactive shell for the job.
    Raises:
      NoPackagesSpecifiedError: if a non-custom job was specified without any
        trainer_uris.
    Returns:
        A constructed Job object.
    """
    job = self.job_class()

    # TODO(b/123467089): Remove yaml file loading here, only parse data objects
    if path:
      data = yaml.load_path(path)
      if data:
        job = encoding.DictToMessage(data, self.job_class)

    if job_name:
      job.jobId = job_name

    if labels is not None:
      job.labels = labels

    if not job.trainingInput:
      job.trainingInput = self.training_input_class()
    additional_fields = {
        'pythonModule': module_name,
        'args': user_args,
        'packageUris': trainer_uri,
        'region': region,
        'jobDir': job_dir,
        'scaleTier': scale_tier,
        'runtimeVersion': runtime_version,
        'pythonVersion': python_version,
        'network': network,
        'serviceAccount': service_account,
        'enableWebAccess': enable_web_access,
    }
    for field_name, value in additional_fields.items():
      if value is not None:
        setattr(job.trainingInput, field_name, value)

    if kms_key:
      arg_utils.SetFieldInMessage(job,
                                  'trainingInput.encryptionConfig.kmsKeyName',
                                  kms_key)

    if custom_train_server_config:
      for field_name, value in custom_train_server_config.GetFieldMap().items():
        if value is not None:
          if (field_name.endswith('Config') and
              not field_name.endswith('TfConfig')):
            if value['imageUri']:
              arg_utils.SetFieldInMessage(
                  job,
                  'trainingInput.{}.imageUri'.format(field_name),
                  value['imageUri'])
            if value['acceleratorConfig']['type']:
              arg_utils.SetFieldInMessage(
                  job,
                  'trainingInput.{}.acceleratorConfig.type'.format(field_name),
                  value['acceleratorConfig']['type'])
            if value['acceleratorConfig']['count']:
              arg_utils.SetFieldInMessage(
                  job,
                  'trainingInput.{}.acceleratorConfig.count'.format(field_name),
                  value['acceleratorConfig']['count'])
            if field_name == 'workerConfig' and value['tpuTfVersion']:
              arg_utils.SetFieldInMessage(
                  job,
                  'trainingInput.{}.tpuTfVersion'.format(field_name),
                  value['tpuTfVersion'])
          else:
            setattr(job.trainingInput, field_name, value)

    if not self.HasPackageURIs(job) and not self.IsCustomContainerTraining(job):
      raise NoPackagesSpecifiedError('Non-custom jobs must have packages.')

    return job

  def HasPackageURIs(self, job):
    return bool(job.trainingInput.packageUris)

  def IsCustomContainerTraining(self, job):
    return bool(job.trainingInput.masterConfig and
                job.trainingInput.masterConfig.imageUri)

  def BuildBatchPredictionJob(self,
                              job_name=None,
                              model_dir=None,
                              model_name=None,
                              version_name=None,
                              input_paths=None,
                              data_format=None,
                              output_path=None,
                              region=None,
                              runtime_version=None,
                              max_worker_count=None,
                              batch_size=None,
                              signature_name=None,
                              labels=None,
                              accelerator_count=None,
                              accelerator_type=None):
    """Builds a Cloud ML Engine Job for batch prediction from flag values.

    Args:
        job_name: value to set for jobName field
        model_dir: str, Google Cloud Storage location of the model files
        model_name: str, value to set for modelName field
        version_name: str, value to set for versionName field
        input_paths: list of input files
        data_format: format of the input files
        output_path: single value for the output location
        region: compute region in which to run the job
        runtime_version: the runtime version in which to run the job
        max_worker_count: int, the maximum number of workers to use
        batch_size: int, the number of records per batch sent to Tensorflow
        signature_name: str, name of input/output signature in the TF meta graph
        labels: Job.LabelsValue, the Cloud labels for the job
        accelerator_count: int, The number of accelerators to attach to the
           machines
       accelerator_type: AcceleratorsValueListEntryValuesEnum, The type of
           accelerator to add to machine.
    Returns:
        A constructed Job object.
    """
    project_id = properties.VALUES.core.project.GetOrFail()

    if accelerator_type:
      accelerator_config_msg = self.GetShortMessage('AcceleratorConfig')
      accelerator_config = accelerator_config_msg(count=accelerator_count,
                                                  type=accelerator_type)
    else:
      accelerator_config = None

    prediction_input = self.prediction_input_class(
        inputPaths=input_paths,
        outputPath=output_path,
        region=region,
        runtimeVersion=runtime_version,
        maxWorkerCount=max_worker_count,
        batchSize=batch_size,
        accelerator=accelerator_config
    )

    prediction_input.dataFormat = prediction_input.DataFormatValueValuesEnum(
        data_format)
    if model_dir:
      prediction_input.uri = model_dir
    elif version_name:
      version_ref = resources.REGISTRY.Parse(
          version_name, collection='ml.projects.models.versions',
          params={'modelsId': model_name, 'projectsId': project_id})
      prediction_input.versionName = version_ref.RelativeName()
    else:
      model_ref = resources.REGISTRY.Parse(
          model_name, collection='ml.projects.models',
          params={'projectsId': project_id})
      prediction_input.modelName = model_ref.RelativeName()

    if signature_name:
      prediction_input.signatureName = signature_name

    return self.job_class(
        jobId=job_name,
        predictionInput=prediction_input,
        labels=labels
    )