File: //snap/google-cloud-cli/current/lib/googlecloudsdk/command_lib/compute/tpus/tpu_vm/util.py
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CLI Utilities for Cloud TPU VM commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import re
import sys
from googlecloudsdk.api_lib.compute import base_classes
from googlecloudsdk.api_lib.compute import metadata_utils
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.calliope import base
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.command_lib.util.args import map_util
from googlecloudsdk.core import exceptions as sdk_core_exceptions
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
import six
class NoFieldsSpecifiedError(sdk_core_exceptions.Error):
"""Error if no fields are specified for an update request."""
class AttachDiskError(sdk_core_exceptions.Error):
"""Error if the update request is invalid for attaching a disk."""
class DetachDiskError(sdk_core_exceptions.Error):
"""Error if the update request is invalid for detaching a disk."""
class BootDiskConfigurationError(sdk_core_exceptions.Error):
"""Error if the boot disk configuration is invalid."""
class WorkerIdsError(sdk_core_exceptions.Error):
"""Error if the boot disk configuration is invalid."""
def GetProject(release_track, ssh_helper):
holder = base_classes.ComputeApiHolder(release_track)
project_name = properties.VALUES.core.project.GetOrFail()
return ssh_helper.GetProject(holder.client, project_name)
def InvertBoolean(boolean):
"""Inverts the boolean value passed in."""
return not boolean
def MergeMetadata(api_version='v2'):
"""Request hook for merging the metadata and metadata from file."""
def Process(unused_ref, args, request):
"""Request hook for merging the metadata and metadata from file.
Args:
unused_ref: ref to the service.
args: The args for this method.
request: The request to be made.
Returns:
Request with metadata field populated.
"""
metadata_dict = metadata_utils.ConstructMetadataDict(
args.metadata, args.metadata_from_file)
tpu_messages = GetMessagesModule(version=api_version)
if request.node.metadata is None:
request.node.metadata = tpu_messages.Node.MetadataValue()
for key, value in six.iteritems(metadata_dict):
request.node.metadata.additionalProperties.append(
tpu_messages.Node.MetadataValue.AdditionalProperty(
key=key, value=value))
return request
return Process
def GetTagsUpdateFromArgs(args, tags):
"""Generate the change to the tags on a resource based on the arguments.
Args:
args: The args for this method.
tags: The current list of tags.
Returns:
The change to the tags after all of the arguments are applied.
"""
tags_update = tags
if args.IsKnownAndSpecified('clear_tags'):
tags_update = []
if args.IsKnownAndSpecified('add_tags'):
tags_update = sorted(set(tags_update + args.add_tags))
if args.IsKnownAndSpecified('remove_tags'):
tags_update = sorted(set(tags_update) - set(args.remove_tags))
return tags_update
def GenerateUpdateMask(api_version='v2'):
"""Request hook for constructing the updateMask for update requests."""
def Process(unused_ref, args, request):
"""Request hook for constructing the updateMask for update requests.
Args:
unused_ref: ref to the service.
args: The args for this method.
request: The request to be made.
Returns:
Request with updateMask field populated.
Raises:
NoFieldsSpecifiedError: if no fields were specified.
AttachDiskError: if the request for attaching a disk is invalid.
DetachDiskError: if the request for detaching a disk is invalid.
"""
update_mask = set()
tpu_messages = GetMessagesModule(version=api_version)
# Since it's possible that different API versions support different flags,
# we must check that the flag is both known in this version and if it is
# specified.
if args.IsKnownAndSpecified('description'):
update_mask.add('description')
if args.IsKnownAndSpecified('internal_ips'):
update_mask.add('network_config.enable_external_ips')
if (args.IsKnownAndSpecified('update_labels') or
args.IsKnownAndSpecified('remove_labels') or
args.IsKnownAndSpecified('clear_labels')):
labels_diff = labels_util.Diff.FromUpdateArgs(args)
if labels_diff.MayHaveUpdates():
labels_update = labels_diff.Apply(
tpu_messages.Node.LabelsValue,
request.node.labels)
if labels_update.needs_update:
request.node.labels = labels_update.labels
update_mask.add('labels')
if (args.IsKnownAndSpecified('add_tags') or
args.IsKnownAndSpecified('remove_tags') or
args.IsKnownAndSpecified('clear_tags')):
tags_update = GetTagsUpdateFromArgs(args, request.node.tags)
if set(tags_update) != set(request.node.tags):
request.node.tags = tags_update
update_mask.add('tags')
if args.IsKnownAndSpecified('metadata_from_file'):
metadata_dict = metadata_utils.ConstructMetadataDict(
None, args.metadata_from_file)
request.node.metadata = tpu_messages.Node.MetadataValue()
for key, value in six.iteritems(metadata_dict):
request.node.metadata.additionalProperties.append(
tpu_messages.Node.MetadataValue.AdditionalProperty(
key=key, value=value))
update_mask.add('metadata')
elif (args.IsKnownAndSpecified('update_metadata') or
args.IsKnownAndSpecified('remove_metadata') or
args.IsKnownAndSpecified('clear_metadata')):
metadata_dict = {}
if request.node.metadata is not None:
for item in request.node.metadata.additionalProperties:
metadata_dict[item.key] = item.value
# Apply flags one by one since we allow multiple flags to be set at once.
# The order should match the flags' descriptions.
metadata_update = map_util.ApplyMapFlags(metadata_dict, None,
None, args.clear_metadata, None,
None)
metadata_update = map_util.ApplyMapFlags(metadata_update, None,
args.update_metadata, None, None,
None)
metadata_update = map_util.ApplyMapFlags(metadata_update, None, None,
None, args.remove_metadata, None)
request.node.metadata = tpu_messages.Node.MetadataValue()
for key, value in six.iteritems(metadata_update):
request.node.metadata.additionalProperties.append(
tpu_messages.Node.MetadataValue.AdditionalProperty(
key=key, value=value))
update_mask.add('metadata')
if args.IsKnownAndSpecified('attach_disk'):
# validates worker
if not args.IsKnownAndSpecified('worker'):
args.worker = ['all']
is_all_workers_specified = ValidateWorkerIdsField(args)
if is_all_workers_specified:
args.worker = []
mode, source = '', ''
for key in args.attach_disk.keys():
if key == 'mode':
mode = args.attach_disk['mode']
elif key == 'source':
source = args.attach_disk['source']
else:
raise AttachDiskError(
'argument --attach-disk: valid keys are [mode, source]; '
'received: ' + key
)
if mode == 'read-only':
mode_enum = tpu_messages.AttachedDisk.ModeValueValuesEnum.READ_ONLY
elif not mode or mode == 'read-write':
mode_enum = tpu_messages.AttachedDisk.ModeValueValuesEnum.READ_WRITE
if len(args.worker) > 1:
raise AttachDiskError(
'argument --attach-disk: can only attach disks in read-write'
' to at most one worker; received: ' + str(args.worker)
)
else:
raise AttachDiskError(
'argument --attach-disk: key mode: can only attach disks in '
'read-write or read-only mode; received: ' + mode
)
# worker is de-duped and sorted.
worker = set(args.worker)
disk_to_attach = tpu_messages.AttachedDisk(
mode=mode_enum,
sourceDisk=source,
)
if api_version == 'v2alpha1':
disk_to_attach.workerIds = sorted(worker)
PreprocessDiskToAttach(request.node.dataDisks, disk_to_attach)
request.node.dataDisks.append(disk_to_attach)
update_mask.add('data_disks')
elif args.IsKnownAndSpecified('detach_disk'):
# validates worker
if not args.IsKnownAndSpecified('worker'):
args.worker = ['all']
is_all_workers_specified = ValidateWorkerIdsField(args)
if is_all_workers_specified:
args.worker = []
if not request.node.dataDisks:
raise DetachDiskError(
'argument --detach-disk: No data disks to detach from current TPU '
'VM.'
)
source_disk_list = []
for disk in request.node.dataDisks:
source_disk_list.append(disk.sourceDisk)
for i, source_disk in enumerate(source_disk_list):
if args.detach_disk != source_disk:
continue
if is_all_workers_specified:
del request.node.dataDisks[i]
break
worker_diff = set(
request.node.dataDisks[i].workerIds) - set(args.worker)
if not worker_diff:
del request.node.dataDisks[i]
break
request.node.dataDisks[i].workerIds = sorted(worker_diff)
break
else:
raise DetachDiskError(
'argument --detach-disk: The specified data disk '
+ args.detach_disk + ' is not currently attached to the TPU VM.'
)
update_mask.add('data_disks')
if not update_mask:
raise NoFieldsSpecifiedError(
'No fields would change as a result of this update; must specify at '
'least one field to update.')
request.updateMask = ','.join(update_mask)
return request
return Process
def RemoveConflictingDefaults(unused_ref, args, request):
"""Unset acceleratorType flag when it conflicts with topology arguments.
Args:
unused_ref: ref to the service.
args: The args for this method.
request: The request to be made.
Returns:
Request with metadata field populated.
"""
if args.topology is not None:
request.node.acceleratorType = None
return request
def GetMessagesModule(version='v2'):
return apis.GetMessagesModule('tpu', version)
def StartRequestHook(api_version='v2'):
"""Declarative request hook for TPU Start command."""
def Process(ref, args, request):
del ref
del args
start_request = GetMessagesModule(version=api_version).StartNodeRequest()
request.startNodeRequest = start_request
return request
return Process
def StopRequestHook(api_version='v2'):
"""Declarative request hook for TPU Stop command."""
def Process(ref, args, request):
del ref
del args
stop_request = GetMessagesModule(version=api_version).StopNodeRequest()
request.stopNodeRequest = stop_request
return request
return Process
def IsTPUVMNode(node):
api_version = six.text_type(node.apiVersion).upper()
return (not api_version.startswith('V1')
and api_version != 'API_VERSION_UNSPECIFIED')
def FilterTPUVMNodes(response, args):
"""Removes Cloud TPU V1 API nodes from the 'list' output.
Used with 'compute tpus tpu-vm list'.
Args:
response: response to ListNodes.
args: the arguments for the list command.
Returns:
A response with only TPU VM (non-V1 API) nodes.
"""
del args
return list(six.moves.filter(IsTPUVMNode, response))
class GuestAttributesListEntry(object):
"""Holder for GetGuestAttributes output."""
def __init__(self, worker_id, namespace, key, value):
self.worker_id = worker_id
self.namespace = namespace
self.key = key
self.value = value
def TransformGuestAttributes(response, args):
"""Transforms the GuestAttributes into a flatter list.
This is needed to make clearer output in the case of TPU pods, since they have
many workers.
Args:
response: response to GetGuestAttributes.
args: the arguments for the GetGuestAttributes command.
Returns:
A list of GuestAttributesListEntry objects.
"""
del args
lst = []
for i, ga in enumerate(response.guestAttributes):
for entry in ga.queryValue.items:
lst.append(
GuestAttributesListEntry(i, entry.namespace, entry.key, entry.value))
return lst
def PreprocessDiskToAttach(current_data_disks_list, disk_to_attach):
"""Preprocesses and validates the disk to attach.
Validates the disk to attach is not already attached to the TPU VM with
different mode or same mode and worker.
Deletes the disk from the current_data_disks_list if it is already attached
to the TPU VM with same mode but different worker.
If the disk is currently attached to the TPU VM with same mode,
joins the current worker list and the new worker list.
Args:
current_data_disks_list: the list of data disks currently attached to the
TPU VM.
disk_to_attach: the disk to attach to the TPU VM.
Raises:
AttachDiskError: if the disk is already attached to the TPU VM
with different mode.
AttachDiskError: if the disk is already attached to the TPU VM with same
mode and worker.
"""
for i, disk in enumerate(current_data_disks_list):
if disk.sourceDisk != disk_to_attach.sourceDisk:
continue
if (disk.mode != disk_to_attach.mode):
raise AttachDiskError(
'argument --attach-disk: the disk is already attached to the TPU '
'VM with different mode.'
)
if not (set(disk_to_attach.workerIds) - set(disk.workerIds)):
raise AttachDiskError(
'argument --attach-disk: the disk is already attached to '
'the same set of workers of TPU VM.'
)
disk_to_attach.workerIds = sorted(
set(disk.workerIds + disk_to_attach.workerIds))
# To avoid disk with same name appear twice in the list.
del current_data_disks_list[i]
def ValidateWorkerIdsField(args):
"""Checks that the worker are numberic strings only.
The only exception is "all" which is a special value that means all
workers. If "all" is specified return True.
Args:
args: the arguments for the update command.
Returns:
True if only one string "all" is specified in args.worker
False otherwise.
Raises:
WorkerIdsError: if the worker are not numberic strings only.
"""
if len(args.worker) == 1 and args.worker[0] == 'all':
return True
for w in args.worker:
if w == 'all' and len(args.worker) > 1:
raise WorkerIdsError(
'argument --worker',
'"all" cannot be specified with other worker.',
)
if not w.isnumeric():
raise WorkerIdsError(
'argument --worker',
'worker must be numeric strings only or '
'"all". e.g. --worker=0,1,2 or --worker=all',
)
return False
def CheckTPUVMNode(response, args):
"""Verifies that the node is a TPU VM node.
If it is not a TPU VM node, exit with an error instead.
Args:
response: response to GetNode.
args: the arguments for the list command.
Returns:
The response to GetNode if the node is TPU VM.
"""
del args
if IsTPUVMNode(response):
return response
log.err.Print('ERROR: Please use "gcloud compute tpus describe" for Cloud TPU'
' nodes that are not TPU VM.')
sys.exit(1)
def ParseBootDiskConfigurations(api_version='v2'):
"""Request hook for parsing boot disk configurations."""
def Process(unused_ref, args, request):
"""Parses configurations for boot disk."""
if not args or not args.IsKnownAndSpecified('boot_disk'):
return request
kms_key_arg_name = 'kms-key'
confidential_compute_arg_name = 'confidential-compute'
for arg_name in args.boot_disk.keys():
if arg_name not in [kms_key_arg_name, confidential_compute_arg_name]:
raise BootDiskConfigurationError(
'--boot-disk only supports arguments: {} and {}'.format(
confidential_compute_arg_name, kms_key_arg_name
)
)
tpu_messages = GetMessagesModule(version=api_version)
enable_confidential_compute = (
args.boot_disk.get(confidential_compute_arg_name, 'False').lower()
== 'true'
)
kms_key = args.boot_disk.get(kms_key_arg_name, None)
if enable_confidential_compute:
if api_version != 'v2alpha1':
raise exceptions.InvalidArgumentException(
'--boot-disk',
'confidential-compute is only available in the alpha release track.'
)
if kms_key is None:
raise BootDiskConfigurationError(
'argument --boot-disk: with confidential-compute={} '
'requires kms-key; received: {}'.format(
enable_confidential_compute, kms_key)
)
boot_disk_config_kwargs = {}
if kms_key:
customer_encryption_key = tpu_messages.CustomerEncryptionKey(
kmsKeyName=kms_key)
boot_disk_config_kwargs['customerEncryptionKey'] = customer_encryption_key
if api_version == 'v2alpha1' and enable_confidential_compute:
boot_disk_config_kwargs['enableConfidentialCompute'] = (
enable_confidential_compute)
if boot_disk_config_kwargs:
request.node.bootDiskConfig = tpu_messages.BootDiskConfig(
**boot_disk_config_kwargs)
return request
return Process
def SetImage(api_version='v2alpha1'):
"""Request hook for setting the source machine image."""
def Process(unused_ref, args, request):
"""Sets the source machine image in the request if provided."""
if args.IsSpecified('image'):
tpu_messages = GetMessagesModule(version=api_version)
if not request.node.bootDiskConfig:
request.node.bootDiskConfig = tpu_messages.BootDiskConfig()
request.node.bootDiskConfig.sourceImage = args.image
return request
return Process
def ProjectIdToProjectNumber(project_id):
"""Returns the Cloud project number associated with the `project_id`."""
crm_message_module = apis.GetMessagesModule('cloudresourcemanager', 'v1')
resource_manager = apis.GetClientInstance('cloudresourcemanager', 'v1')
req = crm_message_module.CloudresourcemanagerProjectsGetRequest(
projectId=project_id)
project = resource_manager.projects.Get(req)
return project.projectNumber
def CreateReservationName(unused_ref, args, request):
"""Request hook for creating the target reservation name.
Args:
unused_ref: ref to the service.
args: The args for this method.
request: The request to be made.
Returns:
Request with reservationName field populated.
"""
short_reservation_name_pattern = '^[a-zA-Z0-9-]+$'
full_reservation_name_pattern = 'projects/{}/locations/{}/reservations/{}'
reservation_name = None
if args.IsKnownAndSpecified('reservation') and re.match(
short_reservation_name_pattern, args.reservation
):
project_id = properties.VALUES.core.project.GetOrFail()
project_number = ProjectIdToProjectNumber(project_id)
reservation_name = full_reservation_name_pattern.format(
project_number, args.zone, args.reservation
)
if reservation_name:
request.node.schedulingConfig.reservationName = reservation_name
return request
def SetProvisioningModel(api_version):
"""Sets the provisioning model enum value."""
def Process(_, args, request):
tpu_messages = GetMessagesModule(api_version)
if args.spot:
request.node.schedulingConfig.provisioningModel = (
tpu_messages.SchedulingConfig.ProvisioningModelValueValuesEnum.SPOT
)
return request
if not args.provisioning_model:
request.node.schedulingConfig.provisioningModel = (
tpu_messages.SchedulingConfig.ProvisioningModelValueValuesEnum.STANDARD
)
return request
try:
normalized_candidate = args.provisioning_model.replace('-', '_').upper()
candidate_enum = (
tpu_messages.SchedulingConfig.ProvisioningModelValueValuesEnum(
normalized_candidate
)
)
except TypeError as e:
raise exceptions.InvalidArgumentException(
'--provisioning-model',
f'{args.provisioning_model} is not a valid provisioning model, must'
' be one of [standard, spot, reservation-bound]',
) from e
request.node.schedulingConfig.provisioningModel = candidate_enum
return request
return Process
class TPUNode(object):
"""Helper to create and modify TPU nodes."""
def __init__(self, release_track):
if release_track == base.ReleaseTrack.ALPHA:
self._api_version = 'v2alpha1'
else:
self._api_version = 'v2'
self.client = apis.GetClientInstance('tpu', self._api_version)
self.messages = apis.GetMessagesModule('tpu', self._api_version)
def GetMessages(self):
return self.messages
def Get(self, name, zone):
"""Retrieves the TPU node in the given zone."""
project = properties.VALUES.core.project.Get(required=True)
fully_qualified_node_name_ref = resources.REGISTRY.Parse(
name,
params={
'locationsId': zone,
'projectsId': project
},
collection='tpu.projects.locations.nodes',
)
request = self.messages.TpuProjectsLocationsNodesGetRequest(
name=fully_qualified_node_name_ref.RelativeName())
return self.client.projects_locations_nodes.Get(request)
def GetGuestAttributes(self, name, zone, worker_id=''):
"""Retrives the Guest Attributes for the nodes."""
project = properties.VALUES.core.project.Get(required=True)
fully_qualified_node_name_ref = resources.REGISTRY.Parse(
name,
params={
'locationsId': zone,
'projectsId': project
},
collection='tpu.projects.locations.nodes',
)
get_guest_attributes_request = self.messages.GetGuestAttributesRequest(
workerIds=[worker_id])
request = self.messages.TpuProjectsLocationsNodesGetGuestAttributesRequest(
name=fully_qualified_node_name_ref.RelativeName(),
getGuestAttributesRequest=get_guest_attributes_request)
return self.client.projects_locations_nodes.GetGuestAttributes(request)
def UpdateNode(self, name, zone, node, update_mask, poller_message):
"""Updates the TPU node in the given zone."""
project = properties.VALUES.core.project.Get(required=True)
fully_qualified_node_name_ref = resources.REGISTRY.Parse(
name,
params={
'locationsId': zone,
'projectsId': project
},
collection='tpu.projects.locations.nodes',
)
request = self.messages.TpuProjectsLocationsNodesPatchRequest(
name=fully_qualified_node_name_ref.RelativeName(),
node=node,
updateMask=update_mask)
# Call UpdateNode to start the LRO.
operation = self.client.projects_locations_nodes.Patch(request)
operation_ref = resources.REGISTRY.ParseRelativeName(
operation.name, collection='tpu.projects.locations.operations'
)
# Wait for the UpdateNode LRO to complete.
return self.WaitForOperation(operation_ref, poller_message)
def UpdateMetadataKey(self, metadata, key, value):
"""Updates a key in the TPU metadata object.
If the key does not exist, it is added.
Args:
metadata: tpu.messages.Node.MetadataValue, the TPU's metadata.
key: str, the key to be updated.
value: str, the new value for the key.
Returns:
The updated metadata.
"""
# If the metadata is empty, return a new metadata object with just the key.
if metadata is None or metadata.additionalProperties is None:
return self.messages.Node.MetadataValue(
additionalProperties=[
self.messages.Node.MetadataValue.AdditionalProperty(
key=key, value=value)])
item = None
for x in metadata.additionalProperties:
if x.key == key:
item = x
break
if item is not None:
item.value = value
else:
# The key is not in the metadata, so append it.
metadata.additionalProperties.append(
self.messages.Node.MetadataValue.AdditionalProperty(
key=key, value=value))
return metadata
def WaitForOperation(self, operation_ref, message):
operation_poller = waiter.CloudOperationPoller(
self.client.projects_locations_nodes,
self.client.projects_locations_operations)
return waiter.WaitFor(operation_poller, operation_ref, message)
class SSHPreppedNode(object):
"""Object that has all the data needed to successfully SSH into a node.
Attributes:
worker_ips: The IPs of the workers of the node.
ssh_helper: The ssh_helper used to SSH into the node.
id: The id of the node.
tpu_name: The unqualified TPU VM name.
instance_names: The name of the instances of the workers of the node.
project: The project associated with the node.
command_list: The list of the commands passed into ssh.
remainder: The remainder list of ssh_args used to pass into the SSH command.
host_key_suffixes: The host key suffixes associated with the node.
user: The user executing the SSH command.
release_track: The release track for the SSH protos (Alpha, Beta, etc.).
enable_batching: A bool indicating if the user enabled batching for the
node.
"""
def __init__(self, tpu_name, user, release_track, enable_batching):
self.tpu_name = tpu_name
self.user = user
self.release_track = release_track
self.enable_batching = enable_batching
self.worker_ips = []
self.ssh_helper = None
self.id = None
self.instance_names = []
self.project = None
self.command_list = []
self.remainder = None
self.host_key_suffixes = []
class SCPPreppedNode(SSHPreppedNode):
"""Object that has all the data needed to successfully SCP into a node.
Attributes:
srcs: The sources for SCP.
dst: The destination for SCP.
"""
def __init__(self, tpu_name, user, release_track, enable_batching, srcs, dst):
super(SCPPreppedNode, self).__init__(
tpu_name, user, release_track, enable_batching
)
self.srcs = srcs
self.dst = dst