HEX
Server: Apache/2.4.65 (Ubuntu)
System: Linux ielts-store-v2 6.8.0-1036-gcp #38~22.04.1-Ubuntu SMP Thu Aug 14 01:19:18 UTC 2025 x86_64
User: root (0)
PHP: 7.2.34-54+ubuntu20.04.1+deb.sury.org+1
Disabled: pcntl_alarm,pcntl_fork,pcntl_waitpid,pcntl_wait,pcntl_wifexited,pcntl_wifstopped,pcntl_wifsignaled,pcntl_wifcontinued,pcntl_wexitstatus,pcntl_wtermsig,pcntl_wstopsig,pcntl_signal,pcntl_signal_get_handler,pcntl_signal_dispatch,pcntl_get_last_error,pcntl_strerror,pcntl_sigprocmask,pcntl_sigwaitinfo,pcntl_sigtimedwait,pcntl_exec,pcntl_getpriority,pcntl_setpriority,pcntl_async_signals,
Upload Files
File: //snap/google-cloud-cli/current/lib/googlecloudsdk/command_lib/ml_engine/flags.py
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides common arguments for the AI Platform command surface."""

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

import argparse
import functools
import itertools
import sys
import textwrap

from googlecloudsdk.api_lib.ml_engine import jobs
from googlecloudsdk.api_lib.ml_engine import versions_api
from googlecloudsdk.api_lib.storage import storage_util
from googlecloudsdk.calliope import actions
from googlecloudsdk.calliope import arg_parsers
from googlecloudsdk.calliope import base
from googlecloudsdk.calliope.concepts import concepts
from googlecloudsdk.command_lib.iam import completers as iam_completers
from googlecloudsdk.command_lib.iam import iam_util as core_iam_util
from googlecloudsdk.command_lib.kms import resource_args as kms_resource_args
from googlecloudsdk.command_lib.ml_engine import constants
from googlecloudsdk.command_lib.ml_engine import models_util
from googlecloudsdk.command_lib.util.apis import arg_utils
from googlecloudsdk.command_lib.util.concepts import concept_parsers
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import properties


# Resource-format templates used to render the `--summarize` job tables.
#
# Outer layout for every job summary. {INPUT} and {OUTPUT} are str.format
# placeholders that the Get*JobSummary() helpers in this module fill with one
# of the nested table formats below.
_JOB_SUMMARY = """\
table[box,title="Job Overview"](
  jobId,
  createTime,
  startTime,
  endTime,
  state,
  {INPUT},
  {OUTPUT})
"""

# Nested table projected from the job's `trainingInput` field.
_JOB_TRAIN_INPUT_SUMMARY_FORMAT = """\
trainingInput:format='table[box,title="Training Input Summary"](
  runtimeVersion:optional,
  region,
  scaleTier:optional,
  pythonModule,
  parameterServerType:optional,
  parameterServerCount:optional,
  masterType:optional,
  workerType:optional,
  workerCount:optional,
  jobDir:optional
)'
"""

# Nested table projected from the job's `trainingOutput` field. {HP_OUTPUT} is
# a str.format placeholder: callers pass either '' or the trials table below.
_JOB_TRAIN_OUTPUT_SUMMARY_FORMAT = """\
trainingOutput:format='table[box,title="Training Output Summary"](
  completedTrialCount:optional:label=TRIALS,
  consumedMLUnits:label=ML_UNITS)'
  {HP_OUTPUT}
"""

# Per-trial table for hyperparameter tuning jobs. The text starts with a comma
# because it is spliced in directly after the training output table above.
_JOB_TRAIN_OUTPUT_TRIALS_FORMAT = """\
,trainingOutput.trials.sort(trialId):format='table[box,title="Training Output Trials"](
  trialId:label=TRIAL,
  finalMetric.objectiveValue:label=OBJECTIVE_VALUE,
  finalMetric.trainingStep:label=STEP,
  hyperparameters.list(separator="\n"))'
"""

# Nested table projected from the job's `predictionInput` field.
_JOB_PREDICT_INPUT_SUMMARY_FORMAT = """\
predictionInput:format='table[box,title="Predict Input Summary"](
  runtimeVersion:optional,
  region,
  model.basename():optional,
  versionName.basename(),
  outputPath,
  uri:optional,
  dataFormat,
  batchSize:optional
)'
"""

# Nested table projected from the job's `predictionOutput` field.
_JOB_PREDICT_OUTPUT_SUMMARY_FORMAT = """\
predictionOutput:format='table[box,title="Predict Output Summary"](
  errorCount,
  nodeHours,
  outputPath,
  predictionCount
  )'
"""


class ArgumentError(exceptions.Error):
  """Raised by this module when a flag value fails validation."""


class MlEngineIamRolesCompleter(iam_completers.IamRolesCompleter):
  """IAM roles completer preconfigured for the AI Platform models collection."""

  def __init__(self, **kwargs):
    # The two fixed settings are passed as explicit keywords next to **kwargs,
    # so a caller that supplies either name still gets a duplicate-keyword
    # TypeError instead of silently overriding them.
    parent = super(MlEngineIamRolesCompleter, self)
    parent.__init__(
        resource_dest='model',
        resource_collection=models_util.MODELS_COLLECTION,
        **kwargs)


def GetDescriptionFlag(noun):
  """Build the optional `--description` flag for a resource.

  Args:
    noun: str, name of the resource kind; interpolated into the help text.

  Returns:
    base.Argument for `--description`, not required and defaulting to None.
  """
  description_help = 'Description of the {noun}.'.format(noun=noun)
  return base.Argument(
      '--description',
      help=description_help,
      default=None,
      required=False)

# Run flags
#
# Flags describing the local cluster layout. Per their own help text, every
# flag below except --distributed is ignored unless --distributed is given.
DISTRIBUTED = base.Argument(
    '--distributed',
    action='store_true',
    default=False,
    help=('Runs the provided code in distributed mode by providing cluster '
          'configurations as environment variables to subprocesses'))
# NOTE(review): the "Default: N" values quoted in the next three help strings
# are not set here (no `default=`); presumably the consuming command applies
# them -- confirm against the caller.
PARAM_SERVERS = base.Argument(
    '--parameter-server-count',
    type=int,
    help=('Number of parameter servers with which to run. '
          'Ignored if --distributed is not specified. Default: 2'))
WORKERS = base.Argument(
    '--worker-count',
    type=int,
    help=('Number of workers with which to run. '
          'Ignored if --distributed is not specified. Default: 2'))
EVALUATORS = base.Argument(
    '--evaluator-count',
    type=int,
    help=('Number of evaluators with which to run. '
          'Ignored if --distributed is not specified. Default: 0'))
# First port of the contiguous block reserved for the local cluster.
START_PORT = base.Argument(
    '--start-port',
    type=int,
    default=27182,
    help="""\
Start of the range of ports reserved by the local cluster. This command will use
a contiguous block of ports equal to parameter-server-count + worker-count + 1.

If --distributed is not specified, this flag is ignored.
""")


# Positional argument naming the operation to act on.
OPERATION_NAME = base.Argument('operation', help='Name of the operation.')


# Path to a YAML/JSON file holding a Job resource. Per the help text below,
# values given on the command line take precedence over values in the file.
CONFIG = base.Argument(
    '--config',
    help="""\
Path to the job configuration file. This file should be a YAML document (JSON
also accepted) containing a Job resource as defined in the API (all fields are
optional): https://cloud.google.com/ml/reference/rest/v1/projects.jobs

EXAMPLES:\n
JSON:

  {
    "jobId": "my_job",
    "labels": {
      "type": "prod",
      "owner": "alice"
    },
    "trainingInput": {
      "scaleTier": "BASIC",
      "packageUris": [
        "gs://my/package/path"
      ],
      "region": "us-east1"
    }
  }

YAML:

  jobId: my_job
  labels:
    type: prod
    owner: alice
  trainingInput:
    scaleTier: BASIC
    packageUris:
    - gs://my/package/path
    region: us-east1



If an option is specified both in the configuration file **and** via command line
arguments, the command line arguments override the configuration file.
""")
# Positional argument naming the job to act on.
JOB_NAME = base.Argument('job', help='Name of the job.')
# Local directory containing the Python source to build for the job.
PACKAGE_PATH = base.Argument(
    '--package-path',
    help="""\
Path to a Python package to build. This should point to a *local* directory
containing the Python source for the job. It will be built using *setuptools*
(which must be installed) using its *parent* directory as context. If the parent
directory contains a `setup.py` file, the build will use that; otherwise,
it will use a simple built-in one.
""")
# Comma-separated list of package archives (local paths or gs:// URLs).
# NOTE(review): `default=[]` is a single list object created at import time.
# ProcessPackages() in this module rebinds args.packages instead of mutating
# it; confirm no other consumer appends to the default in place.
PACKAGES = base.Argument(
    '--packages',
    default=[],
    type=arg_parsers.ArgList(),
    metavar='PACKAGE',
    help="""\
Path to Python archives used for training. These can be local paths
(absolute or relative), in which case they will be uploaded to the Cloud
Storage bucket given by `--staging-bucket`, or Cloud Storage URLs
('gs://bucket-name/path/to/package.tar.gz').
""")

# Help text for `--region` when the choices do not include `global`
# (GetRegionArg(include_global=False)): omitting the flag means the global
# endpoint.
_REGION_FLAG_HELPTEXT = """\
Google Cloud region of the regional endpoint to use for this command.
If unspecified, the command uses the global endpoint of the AI Platform Training
and Prediction API.

Learn more about regional endpoints and see a list of available regions:
 https://cloud.google.com/ai-platform/prediction/docs/regional-endpoints
"""

# Help text for `--region` when `global` is one of the accepted choices
# (GetRegionArg(include_global=True)).
_REGION_FLAG_WITH_GLOBAL_HELPTEXT = """\
Google Cloud region of the regional endpoint to use for this command.
For the global endpoint, the region needs to be specified as `global`.

Learn more about regional endpoints and see a list of available regions:
 https://cloud.google.com/ai-platform/prediction/docs/regional-endpoints
"""


def GetRegionArg(include_global=False):
  """Build the `--region` flag that selects the endpoint for a command.

  Args:
    include_global: bool, if True the accepted choices are
      constants.SUPPORTED_REGIONS_WITH_GLOBAL and the help text tells the user
      to pass `global` for the global endpoint; otherwise the choices are
      constants.SUPPORTED_REGIONS.

  Returns:
    base.Argument for `--region`.
  """
  if include_global:
    region_choices = constants.SUPPORTED_REGIONS_WITH_GLOBAL
    region_help = _REGION_FLAG_WITH_GLOBAL_HELPTEXT
  else:
    region_choices = constants.SUPPORTED_REGIONS
    region_help = _REGION_FLAG_HELPTEXT
  return base.Argument('--region', choices=region_choices, help=region_help)


# Generic, unvalidated `--service-account` flag.
# NOTE(review): TRAINING_SERVICE_ACCOUNT below registers the same flag name,
# so presumably a command adds only one of the two -- confirm at call sites.
SERVICE_ACCOUNT = base.Argument(
    '--service-account',
    required=False,
    help='Specifies the service account for resource access control.')

# Full resource name of the VPC network the job is peered with.
NETWORK = base.Argument(
    '--network',
    help="""\
Full name of the Google Compute Engine
network (https://cloud.google.com/vpc/docs) to which the Job
is peered with. For example, ``projects/12345/global/networks/myVPC''. The
format is of the form projects/{project}/global/networks/{network}, where
{project} is a project number, as in '12345', and {network} is network name.
Private services access must already have been configured
(https://cloud.google.com/vpc/docs/configure-private-services-access)
for the network. If unspecified, the Job is not peered with any network.
""")

# Training-specific `--service-account` flag; unlike SERVICE_ACCOUNT the value
# is passed through the IAM account format validator.
TRAINING_SERVICE_ACCOUNT = base.Argument(
    '--service-account',
    type=core_iam_util.GetIamAccountFormatValidator(),
    required=False,
    help=textwrap.dedent("""\
      The email address of a service account to use when running the
      training appplication. You must have the `iam.serviceAccounts.actAs`
      permission for the specified service account. In addition, the AI Platform
      Training Google-managed service account must have the
      `roles/iam.serviceAccountAdmin` role for the specified service account.
      [Learn more about configuring a service
      account.](https://cloud.google.com/ai-platform/training/docs/custom-service-account)
      If not specified, the AI Platform Training Google-managed service account
      is used by default.
      """))

# Boolean switch (store_true, off by default) requesting interactive shell
# access to training containers.
ENABLE_WEB_ACCESS = base.Argument(
    '--enable-web-access',
    action='store_true',
    required=False,
    default=False,
    help=textwrap.dedent("""\
      Whether you want AI Platform Training to enable [interactive shell
      access]
      (https://cloud.google.com/ai-platform/training/docs/monitor-debug-interactive-shell)
      to training containers. If set to `true`, you can access interactive
      shells at the URIs given by TrainingOutput.web_access_uris or
      HyperparameterOutput.web_access_uris (within TrainingOutput.trials).
      """))


def GetModuleNameFlag(required=True):
  """Build the `--module-name` flag.

  Args:
    required: bool, whether the flag must be given on the command line.

  Returns:
    base.Argument for `--module-name`.
  """
  module_name_flag = base.Argument(
      '--module-name',
      help='Name of the module to run.',
      required=required)
  return module_name_flag


def GetJobDirFlag(upload_help=True, allow_local=False):
  """Build the `--job-dir` flag.

  With allow_local the parsed value is a plain str; otherwise it is parsed by
  storage_util.ObjectReference.FromArgument with allow_empty_object=True.

  Args:
    upload_help: bool, whether to append the help paragraph about package
      uploads. Only useful in remote situations (`jobs submit training`).
    allow_local: bool, whether local directories are accepted (only useful in
      local situations, like `local train`) or input is restricted to
      directories in Cloud Storage.

  Returns:
    base.Argument() for the corresponding `--job-dir` flag.
  """
  location_kind = 'Cloud Storage path'
  if allow_local:
    location_kind += ' or local_directory'
    value_type = str
  else:
    value_type = functools.partial(
        storage_util.ObjectReference.FromArgument, allow_empty_object=True)

  help_sections = ["""\
{dir_type} in which to store training outputs and other data
needed for training.

This path will be passed to your TensorFlow program as the `--job-dir` command-line
arg. The benefit of specifying this field is that AI Platform will validate
the path for use in training. However, note that your training program will need
to parse the provided `--job-dir` argument.
""".format(dir_type=location_kind)]
  if upload_help:
    help_sections.append("""\

If packages must be uploaded and `--staging-bucket` is not provided, this path
will be used instead.
""")

  return base.Argument(
      '--job-dir', type=value_type, help=''.join(help_sections))


def GetUserArgs(local=False):
  """Build the trailing `user_args` positional forwarded to user code.

  Args:
    local: bool, if True the help text also explains how relative paths are
      resolved with respect to `--package-path`.

  Returns:
    base.Argument that captures the remainder of the command line.
  """
  forwarded_help = 'Additional user arguments to be forwarded to user code'
  if local:
    forwarded_help = """\
Additional user arguments to be forwarded to user code. Any relative paths will
be relative to the *parent* directory of `--package-path`.
"""
  return base.Argument(
      'user_args', help=forwarded_help, nargs=argparse.REMAINDER)


# Positional argument naming the model version to act on.
VERSION_NAME = base.Argument('version', help='Name of the model version.')

# AI Platform runtime version; per the help text it is an alternative to
# --master-image-uri.
RUNTIME_VERSION = base.Argument(
    '--runtime-version',
    help=(
        'AI Platform runtime version for this job. Must be specified unless '
        '--master-image-uri is specified instead. It is defined in '
        'documentation along with the list of supported versions: '
        'https://cloud.google.com/ai-platform/prediction/docs/runtime-version-list'  # pylint: disable=line-too-long
    ))


# Seconds between log fetches. The StoreProperty action ties the flag to the
# ml_engine/polling_interval property.
POLLING_INTERVAL = base.Argument(
    '--polling-interval',
    type=arg_parsers.BoundedInt(1, sys.maxsize, unlimited=True),
    required=False,
    default=60,
    action=actions.StoreProperty(properties.VALUES.ml_engine.polling_interval),
    help='Number of seconds to wait between efforts to fetch the latest '
    'log messages.')
# Boolean switch controlling how multiline log messages are emitted.
ALLOW_MULTILINE_LOGS = base.Argument(
    '--allow-multiline-logs',
    action='store_true',
    help='Output multiline log messages as single records.')
# Optional filter restricting log output to a single task.
TASK_NAME = base.Argument(
    '--task-name',
    required=False,
    default=None,
    help='If set, display only the logs for this particular task.')


# Maps API enum value names (keys) to the lowercase strings users type for
# `--framework` (values).
_FRAMEWORK_CHOICES = {
    'TENSORFLOW': 'tensorflow',
    'SCIKIT_LEARN': 'scikit-learn',
    'XGBOOST': 'xgboost'
}
# Choice <-> enum mapper for `--framework`, built at import time from the
# versions API messages module.
FRAMEWORK_MAPPER = arg_utils.ChoiceEnumMapper(
    '--framework',
    (versions_api.GetMessagesModule().
     GoogleCloudMlV1Version.FrameworkValueValuesEnum),
    custom_mappings=_FRAMEWORK_CHOICES,
    help_str=('ML framework used to train this version of the model. '
              'If not specified, defaults to \'tensorflow\''))


def AddKmsKeyFlag(parser, resource):
  """Add the KMS key resource argument for `resource` to `parser`.

  Args:
    parser: the parser to which the KMS key resource argument is added.
    resource: passed through to kms_resource_args.AddKmsKeyResourceArg.
  """
  grantee = "The 'AI Platform Service Agent' service account"
  needed_role = "'Cloud KMS CryptoKey Encrypter/Decrypter'"
  kms_resource_args.AddKmsKeyResourceArg(
      parser,
      resource,
      permission_info='{} must hold permission {}'.format(
          grantee, needed_role))


def AddPythonVersionFlag(parser, context):
  """Add the `--python-version` flag to `parser`.

  Args:
    parser: the parser to which the flag is added.
    context: str, phrase interpolated into the help text after
      "Version of Python used".
  """
  python_version_help = """\
      Version of Python used {context}. Choices are 3.7, 3.5, and 2.7.
      However, this value must be compatible with the chosen runtime version
      for the job.

      Must be used with a compatible runtime version:

      * 3.7 is compatible with runtime versions 1.15 and later.
      * 3.5 is compatible with runtime versions 1.4 through 1.14.
      * 2.7 is compatible with runtime versions 1.15 and earlier.
      """.format(context=context)
  base.Argument(
      '--python-version', help=python_version_help).AddToParser(parser)


def GetModelName(positional=True, required=False):
  """Build the argument naming a model, as a positional or as `--model`.

  Args:
    positional: bool, if True return the positional `model` argument,
      otherwise the `--model` flag.
    required: bool, only applied to the `--model` flag form.

  Returns:
    base.Argument for the model name.
  """
  model_help = 'Name of the model.'
  if not positional:
    return base.Argument('--model', help=model_help, required=required)
  return base.Argument('model', help=model_help)


def ProcessPackages(args):
  """Flatten the parsed `--packages` value into a single list of paths.

  Two shapes of `args.packages` are handled:

  * a list of lists, one inner list per space-separated group (the deprecated
    usage, which is what the warning below is about);
  * a flat list of path strings, which is what a single ArgList-typed flag
    value parses to.

  Only list/tuple elements are expanded. Plain strings are kept whole; blindly
  chaining over them would split every path into single characters.

  Args:
    args: parsed arguments namespace. `args.packages` is replaced in place by
      the flattened list; a value of None is left untouched.
  """
  packages = args.packages
  if packages is None:
    return

  # Each nested list corresponds to one space-separated group on the command
  # line; more than one group means the deprecated syntax was used.
  groups = [entry for entry in packages if isinstance(entry, (list, tuple))]
  if len(groups) > 1:
    log.warning('Use of --packages with space separated values is '
                'deprecated and will not work in the future. Use comma '
                'instead.')

  # flatten packages into a single list
  flattened = []
  for entry in packages:
    if isinstance(entry, (list, tuple)):
      flattened.extend(entry)
    else:
      flattened.append(entry)
  args.packages = flattened


# Bucket used to stage training archives; the value is parsed by
# storage_util.BucketReference.FromArgument.
STAGING_BUCKET = base.Argument(
    '--staging-bucket',
    type=storage_util.BucketReference.FromArgument,
    help="""\
        Bucket in which to stage training archives.

        Required only if a file upload is necessary (that is, other flags
        include local paths) and no other flags implicitly specify an upload
        path.
        """)

# Optional name of the SavedModel signature to use for the job.
SIGNATURE_NAME = base.Argument(
    '--signature-name',
    required=False,
    type=str,
    help="""\
    Name of the signature defined in the SavedModel to use for
    this job. Defaults to DEFAULT_SERVING_SIGNATURE_DEF_KEY in
    https://www.tensorflow.org/api_docs/python/tf/compat/v1/saved_model/signature_constants,
    which is "serving_default". Only applies to TensorFlow models.
    """)


def GetSummarizeFlag():
  """Build the boolean `--summarize` flag.

  Returns:
    base.Argument for `--summarize` (store_true, not required).
  """
  summarize_help = """\
      Summarize job output in a set of human readable tables instead of
      rendering the entire resource as json or yaml. The tables currently rendered
      are:

      * `Job Overview`: Overview of job including, jobId, status and create time.
      * `Training Input Summary`: Summary of input for a training job including
         region, main training python module and scale tier.
      * `Training Output Summary`: Summary of output for a training job including
         the amount of ML units consumed by the job.
      * `Training Output Trials`: Summary of hyperparameter trials run for a
         hyperparameter tuning training job.
      * `Predict Input Summary`: Summary of input for a prediction job including
         region, model verion and output path.
      * `Predict Output Summary`: Summary of output for a prediction job including
         prediction count and output path.

      This flag overrides the `--format` flag. If
      both are present on the command line, a warning is displayed.
      """
  return base.Argument(
      '--summarize',
      help=summarize_help,
      required=False,
      action='store_true')


def GetStandardTrainingJobSummary():
  """Return the table format for a training job without tuning trials.

  Returns:
    str, _JOB_SUMMARY filled with the training input table and the training
    output table (its trials placeholder left empty).
  """
  training_output = _JOB_TRAIN_OUTPUT_SUMMARY_FORMAT.format(HP_OUTPUT='')
  return _JOB_SUMMARY.format(
      INPUT=_JOB_TRAIN_INPUT_SUMMARY_FORMAT, OUTPUT=training_output)


def GetHPTrainingJobSummary():
  """Get tabular format for a hyperparameter tuning ml training job.

  A hyperparameter tuning job is a training job, so the input table is built
  from `trainingInput` (not `predictionInput`). The output is the training
  output table with the per-trial table spliced in.

  Returns:
    str, _JOB_SUMMARY filled with the training input table and the training
    output table followed by the trials table.
  """
  training_output = _JOB_TRAIN_OUTPUT_SUMMARY_FORMAT.format(
      HP_OUTPUT=_JOB_TRAIN_OUTPUT_TRIALS_FORMAT)
  return _JOB_SUMMARY.format(
      INPUT=_JOB_TRAIN_INPUT_SUMMARY_FORMAT, OUTPUT=training_output)


def GetPredictJobSummary():
  """Return the table format for a prediction job.

  Returns:
    str, _JOB_SUMMARY filled with the prediction input and output tables.
  """
  prediction_input = _JOB_PREDICT_INPUT_SUMMARY_FORMAT
  prediction_output = _JOB_PREDICT_OUTPUT_SUMMARY_FORMAT
  return _JOB_SUMMARY.format(INPUT=prediction_input, OUTPUT=prediction_output)


def ModelAttributeConfig():
  """Return the attribute config for the `model` resource parameter."""
  model_attribute = concepts.ResourceParameterAttributeConfig(
      help_text='Model for the {resource}.',
      name='model')
  return model_attribute


def VersionAttributeConfig():
  """Return the attribute config for the `version` resource parameter."""
  version_attribute = concepts.ResourceParameterAttributeConfig(
      help_text='Version for the {resource}.',
      name='version')
  return version_attribute


def GetVersionResourceSpec():
  """Return the resource spec for the `ml.projects.models.versions` collection."""
  collection = 'ml.projects.models.versions'
  return concepts.ResourceSpec(
      collection,
      resource_name='version',
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
      modelsId=ModelAttributeConfig(),
      versionsId=VersionAttributeConfig())


def AddVersionResourceArg(parser, verb):
  """Add a required `version` resource argument to `parser`.

  Args:
    parser: the parser to which the resource argument is added.
    verb: str, phrase appended to the resource help text.
  """
  version_help = 'The AI Platform model {}.'.format(verb)
  version_concept = concept_parsers.ConceptParser.ForResource(
      'version', GetVersionResourceSpec(), version_help, required=True)
  version_concept.AddToParser(parser)


def GetModelResourceSpec(resource_name='model'):
  """Return the resource spec for the `ml.projects.models` collection.

  Args:
    resource_name: str, name under which the resource is presented.

  Returns:
    concepts.ResourceSpec with auto completers disabled.
  """
  collection = 'ml.projects.models'
  return concepts.ResourceSpec(
      collection,
      resource_name=resource_name,
      disable_auto_completers=True,
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG)


def GetModelResourceArg(positional=True, required=False, verb=''):
  """Build a resource argument for an AI Platform model.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    positional: bool, if True the argument is the positional `model`,
      otherwise it is the `--model` flag.
    required: bool, if True means that argument is required.
    verb: str, the verb to describe the resource, such as 'to update'.

  Returns:
    The object returned by concept_parsers.ConceptParser.ForResource; it is
    not added to any parser here.
  """
  arg_name = 'model' if positional else '--model'
  model_help = 'The AI Platform model {}.'.format(verb)
  return concept_parsers.ConceptParser.ForResource(
      arg_name, GetModelResourceSpec(), model_help, required=required)


def GetLocationResourceSpec():
  """Return the resource spec for the `ml.projects.locations` collection."""
  collection = 'ml.projects.locations'
  return concepts.ResourceSpec(
      collection,
      resource_name='location',
      disable_auto_completers=True,
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG)


def GetLocationResourceArg():
  """Build a required resource argument for an AI Platform location.

  Returns:
    The object returned by concept_parsers.ConceptParser.ForResource; it is
    not added to any parser here.
  """
  location_help = 'The location you want to show details for.'
  return concept_parsers.ConceptParser.ForResource(
      'location', GetLocationResourceSpec(), location_help, required=True)


def AddUserCodeArgs(parser):
  """Add the flag group that configures user-provided prediction code.

  The group holds `--prediction-class` and `--package-uris`.

  Args:
    parser: the parser to which the group is added.
  """
  group = base.ArgumentGroup(help="""\
          Configure user code in prediction.

          AI Platform allows a model to have user-provided prediction
          code; these options configure that code.
          """)
  prediction_class_flag = base.Argument(
      '--prediction-class',
      help="""\
          Fully-qualified name of the custom prediction class in the package
          provided for custom prediction.

          For example, `--prediction-class=my_package.SequenceModel`.
          """)
  package_uris_flag = base.Argument(
      '--package-uris',
      type=arg_parsers.ArgList(),
      metavar='PACKAGE_URI',
      help="""\
          Comma-separated list of Cloud Storage URIs ('gs://...') for
          user-supplied Python packages to use.
          """)
  # Same registration order as the flags are declared above.
  for flag in (prediction_class_flag, package_uris_flag):
    group.AddArgument(flag)
  group.AddToParser(parser)


def GetAcceleratorFlag():
  """Build the `--accelerator` dict flag used for GPU serving.

  The flag accepts the keys `type` (required, str) and `count` (int). The
  accelerator types listed in the help text come from
  _OP_ACCELERATOR_TYPE_MAPPER.

  Returns:
    base.Argument for `--accelerator`.
  """
  quoted_types = ', '.join(
      ["'{}'".format(name) for name in _OP_ACCELERATOR_TYPE_MAPPER.choices])
  accelerator_help = """\
Manage the accelerator config for GPU serving. When deploying a model with
Compute Engine Machine Types, a GPU accelerator may also
be selected.

*type*::: The type of the accelerator. Choices are {}.

*count*::: The number of accelerators to attach to each machine running the job.
 If not specified, the default value is 1. Your model must be specially designed
to accommodate more than 1 accelerator per machine. To configure how many
replicas your model has, set the `manualScaling` or `autoScaling`
parameters.""".format(quoted_types)
  accelerator_type = arg_parsers.ArgDict(
      spec={
          'type': str,
          'count': int,
      },
      required_keys=['type'])
  return base.Argument(
      '--accelerator', type=accelerator_type, help=accelerator_help)


def ParseAcceleratorFlag(accelerator):
  """Validate a parsed `--accelerator` value and convert it to a message.

  Args:
    accelerator: dict with a `type` key and an optional `count` key (count
      defaults to 1), or None.

  Returns:
    None when `accelerator` is None, otherwise a
    GoogleCloudMlV1AcceleratorConfig message carrying the count and the type
    enum value.

  Raises:
    ArgumentError: if the type is not one of _ACCELERATOR_TYPE_MAPPER.choices
      or the count is not greater than 0.
  """
  valid_types = list(_ACCELERATOR_TYPE_MAPPER.choices)
  if accelerator is None:
    return None

  requested_type = accelerator.get('type', None)
  if requested_type not in valid_types:
    quoted_types = ["'{}'".format(name) for name in valid_types]
    raise ArgumentError("""\
The type of the accelerator can only be one of the following: {}.
""".format(', '.join(quoted_types)))

  requested_count = accelerator.get('count', 1)
  if requested_count <= 0:
    raise ArgumentError("""\
The count of the accelerator must be greater than 0.
""")

  messages = versions_api.GetMessagesModule()
  config_cls = messages.GoogleCloudMlV1AcceleratorConfig
  type_enum = arg_utils.ChoiceToEnum(
      requested_type, config_cls.TypeValueValuesEnum)
  return config_cls(count=requested_count, type=type_enum)


def AddExplainabilityFlags(parser):
  """Add the flags that configure model explanations to `parser`.

  Adds `--explanation-method`, `--num-integral-steps` and `--num-paths`, in
  that order.

  Args:
    parser: the parser to which the flags are added.
  """
  method_flag = base.ChoiceArgument(
      '--explanation-method',
      choices=['integrated-gradients', 'sampled-shapley', 'xrai'],
      required=False,
      help_str="""\
          Enable explanations and select the explanation method to use.

          The valid options are:
            integrated-gradients: Use Integrated Gradients.
            sampled-shapley: Use Sampled Shapley.
            xrai: Use XRAI.
      """)
  integral_steps_flag = base.Argument(
      '--num-integral-steps',
      type=arg_parsers.BoundedInt(1, sys.maxsize, unlimited=True),
      default=50,
      required=False,
      help="""\
          Number of integral steps for Integrated Gradients. Only valid when
          `--explanation-method=integrated-gradients` or
          `--explanation-method=xrai` is specified.
      """)
  paths_flag = base.Argument(
      '--num-paths',
      type=arg_parsers.BoundedInt(1, sys.maxsize, unlimited=True),
      default=50,
      required=False,
      help="""\
          Number of paths for Sampled Shapley. Only valid when
          `--explanation-method=sampled-shapley` is specified.
      """)
  for flag in (method_flag, integral_steps_flag, paths_flag):
    flag.AddToParser(parser)


def AddCustomContainerFlags(parser, support_tpu_tf_version=False):
  """Add the custom container flags to `parser`.

  Args:
    parser: the parser to which the flags are added.
    support_tpu_tf_version: bool, if True `--tpu-tf-version` is added after
      all the other flags.
  """
  # Builders are listed in the order in which the flags are registered.
  flag_builders = [
      GetMasterMachineType,
      GetMasterAccelerator,
      GetMasterImageUri,
      GetParameterServerMachineTypeConfig,
      GetParameterServerAccelerator,
      GetParameterServerImageUri,
      GetWorkerMachineConfig,
      GetWorkerAccelerator,
      GetWorkerImageUri,
      GetUseChiefInTfConfig,
  ]
  if support_tpu_tf_version:
    flag_builders.append(GetTpuTfVersion)
  for build_flag in flag_builders:
    build_flag().AddToParser(parser)

# Custom Container Flags
#
# The mappers below are defined after some functions that use them
# (GetAcceleratorFlag, ParseAcceleratorFlag). That works because those
# functions only look the names up when they are called.

# Every accelerator type except the UNSPECIFIED placeholder.
_ACCELERATOR_TYPE_MAPPER = arg_utils.ChoiceEnumMapper(
    'generic-accelerator',
    jobs.GetMessagesModule(
    ).GoogleCloudMlV1AcceleratorConfig.TypeValueValuesEnum,
    help_str='Available types of accelerators.',
    include_filter=lambda x: x != 'ACCELERATOR_TYPE_UNSPECIFIED',
    required=False)

# Only the accelerator types whose enum name starts with 'NVIDIA'; used for
# the `--accelerator` help text in GetAcceleratorFlag.
_OP_ACCELERATOR_TYPE_MAPPER = arg_utils.ChoiceEnumMapper(
    'generic-accelerator',
    jobs.GetMessagesModule().GoogleCloudMlV1AcceleratorConfig
    .TypeValueValuesEnum,
    help_str='Available types of accelerators.',
    include_filter=lambda x: x.startswith('NVIDIA'),
    required=False)

# Every autoscaling metric name except the UNSPECIFIED placeholder.
_OP_AUTOSCALING_METRIC_NAME_MAPPER = arg_utils.ChoiceEnumMapper(
    'autoscaling-metric-name',
    versions_api.GetMessagesModule().GoogleCloudMlV1MetricSpec
    .NameValueValuesEnum,
    help_str='The available metric names.',
    include_filter=lambda x: x != 'METRIC_NAME_UNSPECIFIED',
    required=False)

# Shared help text for the *-accelerator flags; {worker_type} is a str.format
# placeholder filled in by each Get*Accelerator() helper.
_ACCELERATOR_TYPE_HELP = """\
   Hardware accelerator config for the {worker_type}. Must specify
   both the accelerator type (TYPE) for each server and the number of
   accelerators to attach to each server (COUNT).
  """


def _ConvertAcceleratorTypeToEnumValue(choice_str):
  """Convert an accelerator type choice string to its enum value.

  Args:
    choice_str: str, the accelerator type as typed by the user.

  Returns:
    The result of arg_utils.ChoiceToEnum for _ACCELERATOR_TYPE_MAPPER's enum.
  """
  mapper = _ACCELERATOR_TYPE_MAPPER
  return arg_utils.ChoiceToEnum(
      choice_str,
      mapper.enum,
      item_type='accelerator',
      valid_choices=mapper.choices)


def _ValidateAcceleratorCount(accelerator_count):
  """Parse an accelerator count and require it to be positive.

  Args:
    accelerator_count: value convertible with int().

  Returns:
    int, the parsed count.

  Raises:
    arg_parsers.ArgumentTypeError: if the parsed count is not greater than 0.
  """
  parsed_count = int(accelerator_count)
  if parsed_count > 0:
    return parsed_count
  raise arg_parsers.ArgumentTypeError(
      'The count of the accelerator must be greater than 0.')


def _ValidateMetricTargetKey(key):
  """Check that `key` is a known autoscaling metric name.

  Args:
    key: the metric name to validate.

  Returns:
    `key`, unchanged, when it is one of
    _OP_AUTOSCALING_METRIC_NAME_MAPPER.choices.

  Raises:
    ArgumentError: if `key` is not one of the known metric names.
  """
  allowed_names = list(_OP_AUTOSCALING_METRIC_NAME_MAPPER.choices)
  if key in allowed_names:
    return key
  quoted_names = ', '.join(["'{}'".format(name) for name in allowed_names])
  raise ArgumentError("""\
The autoscaling metric name can only be one of the following: {}.
""".format(quoted_names))


def _ValidateMetricTargetValue(value):
  """Parse a metric target percentage and require it to be within 0..100.

  Args:
    value: value convertible with int().

  Returns:
    int, the parsed percentage.

  Raises:
    ArgumentError: if `value` cannot be converted to int, or the result is
      outside the inclusive range 0 to 100.
  """
  try:
    percentage = int(value)
  except (TypeError, ValueError):
    raise ArgumentError('Metric target percentage value %s is not an integer.' %
                        value)

  if 0 <= percentage <= 100:
    return percentage
  raise ArgumentError(
      'Metric target value %s is not between 0 and 100.' % value)


def _MakeAcceleratorArgConfigArg(arg_name, arg_help, required=False):
  """Build a dict flag describing an accelerator config (type and count).

  Both keys are required. `type` is converted with
  _ConvertAcceleratorTypeToEnumValue and `count` with
  _ValidateAcceleratorCount.

  Args:
    arg_name: str, name of the flag to create.
    arg_help: str, leading paragraph of the flag's help text.
    required: bool, whether the flag itself is required.

  Returns:
    base.Argument for `arg_name`.
  """
  accelerator_choices = ','.join(_ACCELERATOR_TYPE_MAPPER.choices)
  type_section = '*type*::: Type of the accelerator. Choices are {}'.format(
      accelerator_choices)
  count_section = ('*count*::: Number of accelerators to attach to each '
                   'machine running the job. Must be greater than 0.')
  full_help = """\
{arg_help}

{type_help}

{count_help}

""".format(arg_help=arg_help, type_help=type_section, count_help=count_section)
  config_type = arg_parsers.ArgDict(
      spec={
          'type': _ConvertAcceleratorTypeToEnumValue,
          'count': _ValidateAcceleratorCount,
      },
      required_keys=['type', 'count'])
  return base.Argument(
      arg_name, required=required, type=config_type, help=full_help)


def GetMasterMachineType():
  """Build the optional `--master-machine-type` flag.

  Returns:
    base.Argument for `--master-machine-type`.
  """
  machine_type_help = """\
  Specifies the type of virtual machine to use for training job's master worker.

  You must set this value when `--scale-tier` is set to `CUSTOM`.
  """
  return base.Argument(
      '--master-machine-type', help=machine_type_help, required=False)


def GetMasterAccelerator():
  """Build the `--master-accelerator` dict flag.

  Returns:
    base.Argument created by _MakeAcceleratorArgConfigArg.
  """
  master_help = _ACCELERATOR_TYPE_HELP.format(worker_type='master worker')
  return _MakeAcceleratorArgConfigArg(
      '--master-accelerator', arg_help=master_help)


def GetMasterImageUri():
  """Build the optional `--master-image-uri` flag.

  Returns:
    base.Argument for `--master-image-uri`.
  """
  image_help = ('Docker image to run on each master worker. '
                'This image must be in Container Registry. Only one of '
                '`--master-image-uri` and `--runtime-version` must be '
                'specified.')
  return base.Argument(
      '--master-image-uri', help=image_help, required=False)


def GetParameterServerMachineTypeConfig():
  """Build the optional flag group configuring parameter servers.

  The group itself is not required, but both `--parameter-server-machine-type`
  and `--parameter-server-count` are marked required inside it.

  Returns:
    base.ArgumentGroup holding the two flags.
  """
  type_flag = base.Argument(
      '--parameter-server-machine-type',
      required=True,
      help=('Type of virtual machine to use for training job\'s '
            'parameter servers. This flag must be specified if any of the '
            'other arguments in this group are specified machine to use for '
            'training job\'s parameter servers.'))
  count_flag = base.Argument(
      '--parameter-server-count',
      type=arg_parsers.BoundedInt(1, sys.maxsize, unlimited=True),
      required=True,
      help=('Number of parameter servers to use for the training job.'))

  group = base.ArgumentGroup(
      help='Configure parameter server machine type settings.',
      required=False)
  for flag in (type_flag, count_flag):
    group.AddArgument(flag)
  return group


def GetParameterServerAccelerator():
  """Build the `--parameter-server-accelerator` dict flag.

  Returns:
    base.Argument created by _MakeAcceleratorArgConfigArg.
  """
  server_help = _ACCELERATOR_TYPE_HELP.format(worker_type='parameter servers')
  return _MakeAcceleratorArgConfigArg(
      '--parameter-server-accelerator', arg_help=server_help)


def GetParameterServerImageUri():
  """Build the optional `--parameter-server-image-uri` flag.

  Returns:
    base.Argument for `--parameter-server-image-uri`.
  """
  image_help = ('Docker image to run on each parameter server. '
                'This image must be in Container Registry. If not '
                'specified, the value of `--master-image-uri` is used.')
  return base.Argument(
      '--parameter-server-image-uri', help=image_help, required=False)


def GetWorkerMachineConfig():
  """Build the optional flag group configuring worker nodes.

  The group itself is not required, but both `--worker-machine-type` and
  `--worker-count` are marked required inside it.

  Returns:
    base.ArgumentGroup holding the two flags.
  """
  type_flag = base.Argument(
      '--worker-machine-type',
      required=True,
      help=('Type of virtual '
            'machine to use for training '
            'job\'s worker nodes.'))
  count_flag = base.Argument(
      '--worker-count',
      type=arg_parsers.BoundedInt(1, sys.maxsize, unlimited=True),
      required=True,
      help='Number of worker nodes to use for the training job.')

  group = base.ArgumentGroup(
      help='Configure worker node machine type settings.',
      required=False)
  for flag in (type_flag, count_flag):
    group.AddArgument(flag)
  return group


def GetWorkerAccelerator():
  """Build the `--worker-accelerator` dict flag.

  Returns:
    base.Argument created by _MakeAcceleratorArgConfigArg.
  """
  worker_help = _ACCELERATOR_TYPE_HELP.format(worker_type='worker nodes')
  return _MakeAcceleratorArgConfigArg(
      '--worker-accelerator', arg_help=worker_help)


def GetWorkerImageUri():
  """Returns the optional `--worker-image-uri` flag."""
  help_text = (
      'Docker image to run on each worker node. '
      'This image must be in Container Registry. If not '
      'specified, the value of `--master-image-uri` is used.')
  return base.Argument(
      '--worker-image-uri', help=help_text, required=False)


def GetTpuTfVersion():
  """Returns the optional `--tpu-tf-version` flag."""
  help_text = (
      'Runtime version of TensorFlow used by the container. This field '
      'must be specified if a custom container on the TPU worker is '
      'being used.')
  return base.Argument('--tpu-tf-version', help=help_text, required=False)


def GetUseChiefInTfConfig():
  """Returns the optional `--use-chief-in-tf-config` flag.

  The flag value is parsed with `arg_parsers.ArgBoolean()`.
  """
  help_text = (
      'Use "chief" role in the cluster instead of "master". This is '
      'required for TensorFlow 2.0 and newer versions. Unlike "master" '
      'node, "chief" node does not run evaluation.')
  return base.Argument(
      '--use-chief-in-tf-config',
      help=help_text,
      type=arg_parsers.ArgBoolean(),
      required=False)


def AddMachineTypeFlagToParser(parser):
  """Adds the `--machine-type` flag to the specified parser.

  Args:
    parser: the parser the flag is added to.
  """
  help_text = """\
Type of machine on which to serve the model. Currently only applies to online prediction. For available machine types,
see https://cloud.google.com/ai-platform/prediction/docs/machine-types-online-prediction#available_machine_types.
"""
  machine_type_flag = base.Argument('--machine-type', help=help_text)
  machine_type_flag.AddToParser(parser)


def AddContainerFlags(parser):
  """Registers the custom-container flags on the specified parser.

  Two argument groups are created: one holding the container settings
  (`--image`, `--command`, `--args`, `--env-vars`, `--ports`) and one holding
  the request routes (`--predict-route`, `--health-route`).

  Args:
    parser: the parser the argument groups are added to.
  """
  image_help = """\
Name of the container image to deploy (e.g. gcr.io/myproject/server:latest).
"""
  command_help = """\
Entrypoint for the container image. If not specified, the container
image's default Entrypoint is run.
"""
  args_help = """\
Comma-separated arguments passed to the command run by the container
image. If not specified and no '--command' is provided, the container
image's default Cmd is used.
"""
  ports_help = """\
Container ports to receive requests at. Must be a number between 1 and 65535,
inclusive.
"""

  # Container settings. Every list- or dict-valued flag here uses UpdateAction.
  container = parser.add_argument_group(
      help='Configure the container to be deployed.')
  container.add_argument('--image', help=image_help)
  container.add_argument(
      '--command',
      metavar='COMMAND',
      type=arg_parsers.ArgList(),
      action=arg_parsers.UpdateAction,
      help=command_help)
  container.add_argument(
      '--args',
      metavar='ARG',
      type=arg_parsers.ArgList(),
      action=arg_parsers.UpdateAction,
      help=args_help)
  container.add_argument(
      '--env-vars',
      metavar='KEY=VALUE',
      type=arg_parsers.ArgDict(),
      action=arg_parsers.UpdateAction,
      help='List of key-value pairs to set as environment variables.')
  # Each port is parsed as an integer bounded to the range 1..65535.
  port_list_type = arg_parsers.ArgList(
      element_type=arg_parsers.BoundedInt(1, 65535))
  container.add_argument(
      '--ports',
      metavar='ARG',
      type=port_list_type,
      action=arg_parsers.UpdateAction,
      help=ports_help)

  # Request routing inside the container.
  routes = parser.add_argument_group(
      help='Flags to control the paths that requests and health checks are '
      'sent to.')
  routes.add_argument(
      '--predict-route',
      help='HTTP path to send prediction requests to inside the container.')
  routes.add_argument(
      '--health-route',
      help='HTTP path to send health checks to inside the container.')


def AddAutoScalingFlags(parser):
  """Registers the autoscaling flags on the specified parser.

  A single argument group is created containing `--min-nodes`, `--max-nodes`
  and `--metric-targets`.

  Args:
    parser: the parser the argument group is added to.
  """
  min_nodes_help = """\
The minimum number of nodes to scale this model under load.
"""
  max_nodes_help = """\
The maximum number of nodes to scale this model under load.
"""
  metric_targets_help = """\
List of key-value pairs to set as metrics' target for autoscaling.
Autoscaling could be based on CPU usage or GPU duty cycle, valid key could be
cpu-usage or gpu-duty-cycle.
"""

  scaling = parser.add_argument_group(
      help='Configure the autoscaling settings to be deployed.')
  scaling.add_argument('--min-nodes', type=int, help=min_nodes_help)
  scaling.add_argument('--max-nodes', type=int, help=max_nodes_help)

  # Keys and values of each METRIC-NAME=TARGET pair are passed through the
  # module-level validator callables.
  metric_targets_type = arg_parsers.ArgDict(
      key_type=_ValidateMetricTargetKey,
      value_type=_ValidateMetricTargetValue)
  scaling.add_argument(
      '--metric-targets',
      metavar='METRIC-NAME=TARGET',
      type=metric_targets_type,
      action=arg_parsers.UpdateAction,
      default={},
      help=metric_targets_help)


def AddRequestLoggingConfigFlags(parser):
  """Registers the request-response logging flags on the specified parser.

  A single argument group is created containing `--bigquery-table-name`
  (parsed as `str`) and `--sampling-percentage` (parsed as `float`).

  Args:
    parser: the parser the argument group is added to.
  """
  table_name_help = """\
Fully qualified name (project_id.dataset_name.table_name) of the BigQuery
table where requests and responses are logged.
"""
  sampling_help = """\
Percentage of requests to be logged, expressed as a number between 0 and 1.
For example, set this value to 1 in order to log all requests or to 0.1 to log
10% of requests.
"""
  group = parser.add_argument_group(
      help='Configure request response logging.')
  group.add_argument('--bigquery-table-name', help=table_name_help, type=str)
  group.add_argument('--sampling-percentage', help=sampling_help, type=float)