HEX
Server: Apache/2.4.65 (Ubuntu)
System: Linux ielts-store-v2 6.8.0-1036-gcp #38~22.04.1-Ubuntu SMP Thu Aug 14 01:19:18 UTC 2025 x86_64
User: root (0)
PHP: 7.2.34-54+ubuntu20.04.1+deb.sury.org+1
Disabled: pcntl_alarm,pcntl_fork,pcntl_waitpid,pcntl_wait,pcntl_wifexited,pcntl_wifstopped,pcntl_wifsignaled,pcntl_wifcontinued,pcntl_wexitstatus,pcntl_wtermsig,pcntl_wstopsig,pcntl_signal,pcntl_signal_get_handler,pcntl_signal_dispatch,pcntl_get_last_error,pcntl_strerror,pcntl_sigprocmask,pcntl_sigwaitinfo,pcntl_sigtimedwait,pcntl_exec,pcntl_getpriority,pcntl_setpriority,pcntl_async_signals,
Upload Files
File: //snap/google-cloud-cli/current/lib/surface/ai_platform/jobs/submit/prediction.py
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ai-platform jobs submit batch prediction command."""

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

from googlecloudsdk.api_lib.ml_engine import jobs
from googlecloudsdk.calliope import arg_parsers
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.ml_engine import flags
from googlecloudsdk.command_lib.ml_engine import jobs_util
from googlecloudsdk.command_lib.util.args import labels_util


def _AddAcceleratorFlags(parser):
  """Register the accelerator configuration argument group on *parser*.

  Adds --accelerator-count (a bounded integer, at least 1) together with the
  accelerator-type choice flag from jobs_util, grouped under a single
  'Accelerator Configuration' heading.
  """
  group = base.ArgumentGroup(help='Accelerator Configuration.')

  count_arg = base.Argument(
      '--accelerator-count',
      required=True,
      default=1,
      type=arg_parsers.BoundedInt(lower_bound=1),
      help=('The number of accelerators to attach to the machines.'
            ' Must be >= 1.'))
  group.AddArgument(count_arg)
  # The accelerator type flag is shared with other ml-engine commands.
  group.AddArgument(jobs_util.AcceleratorFlagMap().choice_arg)
  group.AddToParser(parser)


def _AddSubmitPredictionArgs(parser):
  """Add arguments for `jobs submit prediction` command.

  Registers the positional job name, the mutually exclusive model-source
  flags, and the input/output/region/tuning flags. Shared by both the
  GA/beta and the alpha command classes below.

  Args:
    parser: An argparse-style parser to register the arguments on.
  """
  parser.add_argument('job', help='Name of the batch prediction job.')
  # Exactly one of --model-dir / --model must be supplied.
  model_group = parser.add_mutually_exclusive_group(required=True)
  model_group.add_argument(
      '--model-dir',
      help=('Cloud Storage location where '
            'the model files are located.'))
  model_group.add_argument(
      '--model', help='Name of the model to use for prediction.')
  # --version is only meaningful together with --model (see help text);
  # no local validation is done here.
  parser.add_argument(
      '--version',
      help="""\
Model version to be used.

This flag may only be given if --model is specified. If unspecified, the default
version of the model will be used. To list versions for a model, run

    $ gcloud ai-platform versions list
""")
  # input location is a repeated field.
  parser.add_argument(
      '--input-paths',
      type=arg_parsers.ArgList(min_length=1),
      required=True,
      metavar='INPUT_PATH',
      help="""\
Cloud Storage paths to the instances to run prediction on.

Wildcards (```*```) accepted at the *end* of a path. More than one path can be
specified if multiple file patterns are needed. For example,

  gs://my-bucket/instances*,gs://my-bucket/other-instances1

will match any objects whose names start with `instances` in `my-bucket` as well
as the `other-instances1` bucket, while

  gs://my-bucket/instance-dir/*

will match any objects in the `instance-dir` "directory" (since directories
aren't a first-class Cloud Storage concept) of `my-bucket`.
""")
  # Choices come from the shared flag map; Run() maps the parsed choice back
  # to the enum via the same map.
  jobs_util.DataFormatFlagMap().choice_arg.AddToParser(parser)
  parser.add_argument(
      '--output-path', required=True,
      help='Cloud Storage path to which to save the output. '
      'Example: gs://my-bucket/output.')
  parser.add_argument(
      '--region',
      required=True,
      help='The Compute Engine region to run the job in.')
  parser.add_argument(
      '--max-worker-count',
      required=False,
      type=int,
      help=('The maximum number of workers to be used for parallel processing. '
            'Defaults to 10 if not specified.'))
  parser.add_argument(
      '--batch-size',
      required=False,
      type=int,
      help=('The number of records per batch. The service will buffer '
            'batch_size number of records in memory before invoking TensorFlow.'
            ' Defaults to 64 if not specified.'))

  # Flags shared with other ml-engine commands (defined in
  # command_lib.ml_engine.flags and labels_util).
  flags.SIGNATURE_NAME.AddToParser(parser)
  flags.RUNTIME_VERSION.AddToParser(parser)
  labels_util.AddCreateLabelsFlags(parser)


@base.ReleaseTracks(base.ReleaseTrack.GA,
                    base.ReleaseTrack.BETA)
class Prediction(base.Command):
  """Start an AI Platform batch prediction job."""

  @staticmethod
  def Args(parser):
    """Register the command's flags and output format."""
    _AddSubmitPredictionArgs(parser)
    parser.display_info.AddFormat(jobs_util.JOB_FORMAT)

  def Run(self, args):
    """Submit the batch prediction job and return the created job resource."""
    client = jobs.JobsClient()
    # Translate the user-facing --data-format choice back into its enum.
    fmt = jobs_util.DataFormatFlagMap().GetEnumForChoice(args.data_format)
    job_labels = jobs_util.ParseCreateLabels(client, args)
    return jobs_util.SubmitPrediction(
        client,
        args.job,
        model_dir=args.model_dir,
        model=args.model,
        version=args.version,
        input_paths=args.input_paths,
        data_format=fmt.name,
        output_path=args.output_path,
        region=args.region,
        runtime_version=args.runtime_version,
        max_worker_count=args.max_worker_count,
        batch_size=args.batch_size,
        signature_name=args.signature_name,
        labels=job_labels)


@base.ReleaseTracks(base.ReleaseTrack.ALPHA)
class PredictionAlpha(base.Command):
  """Start an AI Platform batch prediction job."""

  @staticmethod
  def Args(parser):
    """Register the shared flags plus the alpha-only accelerator flags."""
    _AddSubmitPredictionArgs(parser)
    _AddAcceleratorFlags(parser)
    parser.display_info.AddFormat(jobs_util.JOB_FORMAT)

  def Run(self, args):
    """Submit the batch prediction job and return the created job resource."""
    client = jobs.JobsClient()
    # Translate the user-facing --data-format choice back into its enum.
    fmt = jobs_util.DataFormatFlagMap().GetEnumForChoice(args.data_format)
    job_labels = jobs_util.ParseCreateLabels(client, args)
    return jobs_util.SubmitPrediction(
        client,
        args.job,
        model_dir=args.model_dir,
        model=args.model,
        version=args.version,
        input_paths=args.input_paths,
        data_format=fmt.name,
        output_path=args.output_path,
        region=args.region,
        runtime_version=args.runtime_version,
        max_worker_count=args.max_worker_count,
        batch_size=args.batch_size,
        signature_name=args.signature_name,
        labels=job_labels,
        # Alpha track additionally forwards the accelerator configuration.
        accelerator_type=args.accelerator_type,
        accelerator_count=args.accelerator_count)