# File: //snap/google-cloud-cli/394/lib/googlecloudsdk/api_lib/ml/speech/client.py
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Speech-to-text V2 client."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import contextlib
import os
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.storage import storage_util
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import exceptions
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.command_lib.ml.speech import flag_validations
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core.util import files
from six.moves import urllib
# Discovery identifiers for the Speech-to-Text v2 API surface.
_API_NAME = 'speech'
_API_VERSION = 'v2'
@contextlib.contextmanager
def _OverrideEndpoint(override):
  """Temporarily replaces the speech API endpoint override.

  Args:
    override: str, the endpoint URL to use inside the context.

  Yields:
    None. The previous override value is restored on exit, even on error.
  """
  override_property = getattr(
      properties.VALUES.api_endpoint_overrides, _API_NAME
  )
  previous_value = override_property.Get()
  try:
    override_property.Set(override)
    yield
  finally:
    override_property.Set(previous_value)
class SpeechV2Client(object):
"""Speech V2 API client wrappers."""
def __init__(self):
  """Sets up the messages module, resource parser, and encoding enum map."""
  client_class = apis.GetClientClass(_API_NAME, _API_VERSION)
  # Only the host part of the base URL is kept; regional endpoints are
  # assembled from it in _GetClientForLocation.
  self._net_loc = urllib.parse.urlsplit(client_class.BASE_URL).netloc
  self._messages = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  self._resource_parser = resources.Registry()
  self._resource_parser.RegisterApiByName(_API_NAME, _API_VERSION)
  # Maps the --encoding flag value onto the decoding-config proto enum.
  encoding_enum = (
      self._messages.ExplicitDecodingConfig.EncodingValueValuesEnum
  )
  self._encoding_to_message = {
      'LINEAR16': encoding_enum.LINEAR16,
      'MULAW': encoding_enum.MULAW,
      'ALAW': encoding_enum.ALAW,
  }
def _GetClientForLocation(self, location):
  """Returns an API client whose endpoint targets the given location."""
  regional_endpoint = 'https://{}-{}/'.format(location, self._net_loc)
  with _OverrideEndpoint(regional_endpoint):
    return apis.GetClientInstance(_API_NAME, _API_VERSION)
def _RecognizerServiceForLocation(self, location):
return self._GetClientForLocation(location).projects_locations_recognizers
def _OperationsServiceForLocation(self, location):
return self._GetClientForLocation(location).projects_locations_operations
def _LocationsServiceForLocation(self, location):
return self._GetClientForLocation(location).projects_locations
def CreateRecognizer(
self,
resource,
display_name,
model,
language_codes,
recognition_config,
):
"""Call API CreateRecognizer method with provided arguments."""
recognizer = self._messages.Recognizer(displayName=display_name)
recognizer.model = model
recognizer.languageCodes = language_codes
recognizer.defaultRecognitionConfig = recognition_config
request = self._messages.SpeechProjectsLocationsRecognizersCreateRequest(
parent=resource.Parent(
parent_collection='speech.projects.locations'
).RelativeName(),
recognizerId=resource.Name(),
recognizer=recognizer,
)
return self._RecognizerServiceForLocation(
location=resource.Parent().Name()
).Create(request)
def GetRecognizer(self, resource):
request = self._messages.SpeechProjectsLocationsRecognizersGetRequest(
name=resource.RelativeName()
)
return self._RecognizerServiceForLocation(
location=resource.Parent().Name()
).Get(request)
def DeleteRecognizer(self, resource):
request = self._messages.SpeechProjectsLocationsRecognizersDeleteRequest(
name=resource.RelativeName()
)
return self._RecognizerServiceForLocation(
location=resource.Parent().Name()
).Delete(request)
def ListRecognizers(self, location_resource, limit=None, page_size=None):
  """Lists recognizers in the given location.

  Args:
    location_resource: the location resource whose recognizers to list.
    limit: int or None, maximum total number of recognizers to yield.
    page_size: int or None, number of recognizers requested per API page.

  Returns:
    A generator of Recognizer messages.
  """
  request = self._messages.SpeechProjectsLocationsRecognizersListRequest(
      parent=location_resource.RelativeName()
  )
  # The request message's page-size field is `pageSize` (camelCase), as the
  # batch_size_attribute below shows; the previous
  # `request.page_size = page_size` assignment targeted a nonexistent field
  # and raised AttributeError on protorpc messages. YieldFromList sets
  # `pageSize` itself via batch_size/batch_size_attribute, so no manual
  # assignment is needed.
  return list_pager.YieldFromList(
      self._RecognizerServiceForLocation(location_resource.Name()),
      request,
      limit=limit,
      batch_size_attribute='pageSize',
      batch_size=page_size,
      field='recognizers',
  )
def UpdateRecognizer(
self,
resource,
display_name,
model,
language_codes,
recognition_config,
update_mask,
):
"""Call API UpdateRecognizer method with provided arguments."""
recognizer = self._messages.Recognizer()
if display_name is not None:
recognizer.displayName = display_name
update_mask.append('display_name')
if model is not None:
recognizer.model = model
update_mask.append('model')
if language_codes is not None:
recognizer.languageCodes = language_codes
update_mask.append('language_codes')
recognizer.defaultRecognitionConfig = recognition_config
request = self._messages.SpeechProjectsLocationsRecognizersPatchRequest(
name=resource.RelativeName(),
recognizer=recognizer,
updateMask=','.join(update_mask),
)
return self._RecognizerServiceForLocation(
location=resource.Parent().Name()
).Patch(request)
def RunShort(
    self,
    resource,
    audio,
    recognition_config,
    update_mask,
):
  """Calls the synchronous Recognize API method.

  Args:
    resource: the recognizer resource reference to use.
    audio: str, a local file path or a Cloud Storage URL of the audio.
    recognition_config: RecognitionConfig message with per-request
      overrides.
    update_mask: list of str, names of the config fields to override.

  Returns:
    The response from the Recognize call.
  """
  recognize_request = self._messages.RecognizeRequest()
  # Local files are sent inline as bytes; storage URLs are passed by
  # reference. NOTE(review): if `audio` is neither, both fields stay unset
  # and the API decides how to respond — confirm callers pre-validate.
  if os.path.isfile(audio):
    recognize_request.content = files.ReadBinaryFileContents(audio)
  elif storage_util.ObjectReference.IsStorageUrl(audio):
    recognize_request.uri = audio
  recognize_request.config = recognition_config
  recognize_request.configMask = ','.join(update_mask)

  wrapped_request = (
      self._messages.SpeechProjectsLocationsRecognizersRecognizeRequest(
          recognizeRequest=recognize_request,
          recognizer=resource.RelativeName(),
      )
  )
  service = self._RecognizerServiceForLocation(
      location=resource.Parent().Name()
  )
  return service.Recognize(wrapped_request)
def RunBatch(
self,
resource,
audio,
recognition_config,
update_mask,
):
"""Call API Recognize method with provided arguments in batch mode."""
batch_audio_metadata = self._messages.BatchRecognizeFileMetadata(uri=audio)
recognize_req = self._messages.BatchRecognizeRequest(
recognizer=resource.RelativeName(),
files=[batch_audio_metadata],
)
recognizer_service = self._RecognizerServiceForLocation(
location=resource.Parent().Name()
)
recognize_req.config = recognition_config
recognize_req.recognitionOutputConfig = (
self._messages.RecognitionOutputConfig(
inlineResponseConfig=self._messages.InlineOutputConfig()
)
)
recognize_req.configMask = ','.join(update_mask)
return recognizer_service.BatchRecognize(recognize_req)
def GetOperationRef(self, operation):
"""Converts an Operation to a Resource."""
return self._resource_parser.ParseRelativeName(
operation.name, 'speech.projects.locations.operations'
)
def WaitForRecognizerOperation(self, location, operation_ref, message):
  """Waits for a Recognizer operation to complete.

  Polls the Speech operation service until the operation completes, fails,
  or the 20-second wait ceiling elapses.

  Args:
    location: str, the location of the resource.
    operation_ref: a resource reference created by GetOperationRef
      describing the operation.
    message: str, the message to display to the user while they wait.

  Returns:
    The result of the finished operation, extracted through the recognizer
    service.
  """
  recognizer_poller = waiter.CloudOperationPoller(
      result_service=self._RecognizerServiceForLocation(location),
      operation_service=self._OperationsServiceForLocation(location),
  )
  return waiter.WaitFor(
      poller=recognizer_poller,
      operation_ref=operation_ref,
      message=message,
      pre_start_sleep_ms=100,
      max_wait_ms=20000,  # 20 seconds.
  )
def WaitForBatchRecognizeOperation(self, location, operation_ref, message):
  """Waits for a BatchRecognize operation to complete.

  Polls the Speech operation service until the operation completes, fails,
  or the 24-hour wait ceiling elapses.

  Args:
    location: str, the location of the resource.
    operation_ref: a resource reference created by GetOperationRef
      describing the operation.
    message: str, the message to display to the user while they wait.

  Returns:
    The finished operation's result, passed through the identity extractor
    below (no separate result service is consulted).
  """
  passthrough_poller = waiter.CloudOperationPollerNoResources(
      self._OperationsServiceForLocation(location),
      lambda result: result,
  )
  return waiter.WaitFor(
      passthrough_poller,
      operation_ref,
      message=message,
      wait_ceiling_ms=86400000,  # 24 hours.
  )
def GetLocation(self, location_resource):
request = self._messages.SpeechProjectsLocationsGetRequest(
name=location_resource.RelativeName()
)
return self._LocationsServiceForLocation(
location=location_resource.Name()
).Get(request)
def ListLocations(self, filter_str=None, limit=None, page_size=None):
  """Lists available speech locations for the current project.

  Args:
    filter_str: str or None, server-side filter expression.
    limit: int or None, maximum total number of locations to yield.
    page_size: int or None, number of locations requested per API page.

  Returns:
    A generator of Location messages.
  """
  # NOTE(review): this passes the bare project ID as the request `name`;
  # locations List methods usually expect a 'projects/<id>' resource name —
  # confirm against the v2 API before changing.
  request = self._messages.SpeechProjectsLocationsListRequest(
      name=properties.VALUES.core.project.Get()
  )
  if filter_str:
    request.filter = filter_str
  # The request message's page-size field is `pageSize` (camelCase); the
  # previous `request.page_size = page_size` assignment targeted a
  # nonexistent field and raised AttributeError on protorpc messages.
  # YieldFromList sets `pageSize` itself via batch_size/
  # batch_size_attribute, so no manual assignment is needed.
  return list_pager.YieldFromList(
      self._LocationsServiceForLocation('global'),
      request,
      limit=limit,
      batch_size_attribute='pageSize',
      batch_size=page_size,
      field='locations',
  )
def InitializeRecognitionConfig(
self, model=None, language_codes=None, update_mask=None
):
"""creates a recognition config object and initializes it with model and language codes."""
recognition_config = self._messages.RecognitionConfig()
if model is not None:
recognition_config.model = model
if language_codes is not None:
recognition_config.languageCodes = language_codes
if update_mask is None:
return recognition_config, update_mask
if model is not None:
update_mask.append('model')
if language_codes is not None:
update_mask.append('language_codes')
return recognition_config, update_mask
def InitializeDecodingConfigFromArgs(
self,
recognition_config,
args,
default_to_auto_decoding_config=False,
update_mask=None,
):
return self._InitializeDecodingConfigRecognizerCommand(
recognition_config,
args.encoding,
args.sample_rate,
args.audio_channel_count,
default_to_auto_decoding_config=default_to_auto_decoding_config,
update_mask=update_mask,
)
def _InitializeDecodingConfigRecognizerCommand(
    self,
    recognition_config,
    encoding,
    sample_rate,
    audio_channel_count,
    default_to_auto_decoding_config=False,
    update_mask=None,
):
  """Initializes the decoding config on `recognition_config`.

  Chooses auto-detect decoding for 'AUTO' (or as a fallback when
  `default_to_auto_decoding_config` is set) and explicit decoding for the
  encodings listed in flag_validations.EXPLICIT_ENCODING_OPTIONS, applying
  sample rate and audio channel count where given.

  Args:
    recognition_config: the RecognitionConfig message to update in place.
    encoding: str or None, 'AUTO' or one of the explicit encoding options.
    sample_rate: int or None, sample rate in hertz for explicit decoding.
    audio_channel_count: int or None, channel count for explicit decoding.
    default_to_auto_decoding_config: bool, use auto-detect decoding when
      `encoding` is None.
    update_mask: list of str or None; when given, extended in place with
      the decoding-config field paths that were set.

  Returns:
    A (recognition_config, update_mask) tuple.

  Raises:
    exceptions.InvalidArgumentException: if `encoding` is neither AUTO nor
      a supported explicit encoding.
  """
  if encoding is not None:
    if encoding == 'AUTO':
      recognition_config.autoDecodingConfig = (
          self._messages.AutoDetectDecodingConfig()
      )
    elif encoding in flag_validations.EXPLICIT_ENCODING_OPTIONS:
      # A fresh explicit config replaces any existing one; sample rate and
      # channel count are only attached when supplied.
      recognition_config.explicitDecodingConfig = (
          self._messages.ExplicitDecodingConfig()
      )
      recognition_config.explicitDecodingConfig.encoding = (
          self._encoding_to_message[encoding]
      )
      if sample_rate is not None:
        recognition_config.explicitDecodingConfig.sampleRateHertz = (
            sample_rate
        )
      if audio_channel_count is not None:
        recognition_config.explicitDecodingConfig.audioChannelCount = (
            audio_channel_count
        )
    else:
      raise exceptions.InvalidArgumentException(
          '--encoding',
          '[--encoding] must be set to LINEAR16, MULAW, ALAW, or AUTO.',
      )
  elif default_to_auto_decoding_config:
    recognition_config.autoDecodingConfig = (
        self._messages.AutoDetectDecodingConfig()
    )
  # Without a mask to record into, the config set above is returned as-is.
  if update_mask is None:
    return recognition_config, update_mask
  if encoding == 'AUTO':
    update_mask.append('auto_decoding_config')
  elif encoding in flag_validations.EXPLICIT_ENCODING_OPTIONS:
    update_mask.append('explicit_decoding_config')
  elif default_to_auto_decoding_config:
    update_mask.append('auto_decoding_config')
  # NOTE(review): when an explicit encoding was given, sampleRateHertz /
  # audioChannelCount were already set above; these branches set them again
  # and record the field paths. When no encoding was given, they create a
  # bare explicitDecodingConfig with no encoding — presumably intended for
  # update flows; confirm against callers.
  if sample_rate is not None:
    if recognition_config.explicitDecodingConfig is None:
      recognition_config.explicitDecodingConfig = (
          self._messages.ExplicitDecodingConfig()
      )
    recognition_config.explicitDecodingConfig.sampleRateHertz = sample_rate
    update_mask.append('explicit_decoding_config.sample_rate_hertz')
  if audio_channel_count is not None:
    if recognition_config.explicitDecodingConfig is None:
      recognition_config.explicitDecodingConfig = (
          self._messages.ExplicitDecodingConfig()
      )
    recognition_config.explicitDecodingConfig.audioChannelCount = (
        audio_channel_count
    )
    update_mask.append('explicit_decoding_config.audio_channel_count')
  return recognition_config, update_mask
def InitializeAdaptationConfigFromArgs(
self,
args,
update_mask=None,
):
"""Initializes PhraseSets based on hints."""
return self._InitializeAdaptationConfigRecognizeRequest(
args.hint_phrases, args.hint_phrase_sets, args.hint_boost, update_mask
)
def _InitializeAdaptationConfigRecognizeRequest(
self, hint_phrases, hint_phrase_sets, hint_boost=5.0, update_mask=None
):
"""Initializes PhraseSets based on phrases and phrase sets."""
speech_adaptation_phrase_sets = []
if hint_phrases:
inline_phrase_set = self._messages.PhraseSet(
phrases=[
self._messages.Phrase(value=hint_phrase, boost=5.0)
for hint_phrase in hint_phrases
],
boost=hint_boost,
)
inline_adaptation_phrase_set = self._messages.AdaptationPhraseSet(
inlinePhraseSet=inline_phrase_set
)
speech_adaptation_phrase_sets.append(inline_adaptation_phrase_set)
if hint_phrase_sets:
for hint_phrase_set in hint_phrase_sets:
adaptation_phrase_set = self._messages.AdaptationPhraseSet(
phraseSet=hint_phrase_set
)
speech_adaptation_phrase_sets.append(adaptation_phrase_set)
speech_adaptation_config = self._messages.SpeechAdaptation(
phraseSets=speech_adaptation_phrase_sets
)
if update_mask is not None:
update_mask.append('adaptation')
return speech_adaptation_config, update_mask
def InitializeASRFeaturesFromArgs(
self,
args,
update_mask=None,
):
"""Collects features from the provided arguments."""
features_config = self._messages.RecognitionFeatures()
inner_update_mask = []
if args.profanity_filter is not None:
features_config.profanityFilter = args.profanity_filter
inner_update_mask.append('features.profanityFilter')
if args.enable_word_time_offsets is not None:
features_config.enableWordTimeOffsets = args.enable_word_time_offsets
inner_update_mask.append('features.enableWordTimeOffsets')
if args.enable_word_confidence is not None:
features_config.enableWordConfidence = args.enable_word_confidence
inner_update_mask.append('features.enableWordConfidence')
if args.enable_automatic_punctuation is not None:
features_config.enableAutomaticPunctuation = (
args.enable_automatic_punctuation
)
inner_update_mask.append('features.enableAutomaticPunctuation')
if args.enable_spoken_punctuation is not None:
features_config.enableSpokenPunctuation = args.enable_spoken_punctuation
inner_update_mask.append('features.enableSpokenPunctuation')
if args.enable_spoken_emojis is not None:
features_config.enableSpokenEmojis = args.enable_spoken_emojis
inner_update_mask.append('features.enableSpokenEmojis')
if (
args.min_speaker_count is not None
and args.max_speaker_count is not None
):
features_config.diarizationConfig = (
self._messages.SpeakerDiarizationConfig(
minSpeakerCount=args.min_speaker_count,
maxSpeakerCount=args.max_speaker_count,
)
)
inner_update_mask.append('features.diarizationConfig.minSpeakerCount')
inner_update_mask.append('features.diarizationConfig.maxSpeakerCount')
if args.separate_channel_recognition:
features_config.multiChannelMode = (
self._messages.RecognitionFeatures.MultiChannelModeValueValuesEnum.SEPARATE_RECOGNITION_PER_CHANNEL
)
inner_update_mask.append('features.multiChannelMode')
elif args.separate_channel_recognition is not None:
features_config.multiChannelMode = (
self._messages.RecognitionFeatures.MultiChannelModeValueValuesEnum.MULTI_CHANNEL_MODE_UNSPECIFIED
)
inner_update_mask.append('features.multiChannelMode')
if args.max_alternatives is not None:
features_config.maxAlternatives = args.max_alternatives
inner_update_mask.append('features.maxAlternatives')
if update_mask is not None:
update_mask.extend(inner_update_mask)
return features_config, update_mask