Source code for airflow.providers.google.cloud.triggers.mlengine
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Sequence
from typing import Any

from airflow.providers.google.cloud.hooks.mlengine import MLEngineAsyncHook
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.triggers.base import BaseTrigger, TriggerEvent
class MLEngineStartTrainingJobTrigger(BaseTrigger):
    """
    MLEngineStartTrainingJobTrigger run on the trigger worker to perform starting training job operation.

    The trigger polls the status of a single job until the hook reports
    something other than ``"pending"`` and then emits exactly one
    ``TriggerEvent``.

    :param conn_id: Reference to google cloud connection id
    :param job_id: The ID of the job whose status is polled.
    :param region: Region passed in by the caller. Stored and serialized;
        not read by the polling loop in this class.
    :param poll_interval: polling period in seconds to check for the status
        (defaults to 4.0).
    :param package_uris: Training job setting. Stored and serialized only so
        the trigger can be re-created with the same arguments.
    :param training_python_module: Training job setting, stored and
        serialized only.
    :param training_args: Training job setting, stored and serialized only.
    :param runtime_version: Training job setting, stored and serialized only.
    :param python_version: Training job setting, stored and serialized only.
    :param job_dir: Training job setting, stored and serialized only.
    :param project_id: Google Cloud Project where the job is running
    :param labels: Training job setting, stored and serialized only.
    :param gcp_conn_id: Connection id stored on the trigger (defaults to
        ``"google_cloud_default"``).
    :param impersonation_chain: Value stored on the trigger: a single
        string, a sequence of strings, or ``None``.
    """

    def __init__(
        self,
        conn_id: str,
        job_id: str,
        region: str,
        poll_interval: float = 4.0,
        package_uris: list[str] | None = None,
        training_python_module: str | None = None,
        training_args: list[str] | None = None,
        runtime_version: str | None = None,
        python_version: str | None = None,
        job_dir: str | None = None,
        project_id: str = PROVIDE_PROJECT_ID,
        labels: dict[str, str] | None = None,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
    ):
        super().__init__()
        self.log.info("Using the connection %s .", conn_id)
        self.conn_id = conn_id
        self.job_id = job_id
        self._job_conn = None
        self.project_id = project_id
        self.region = region
        self.poll_interval = poll_interval
        self.runtime_version = runtime_version
        self.python_version = python_version
        self.job_dir = job_dir
        self.package_uris = package_uris
        self.training_python_module = training_python_module
        self.training_args = training_args
        self.labels = labels
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """
        Serialize MLEngineStartTrainingJobTrigger arguments and classpath.

        :return: A ``(classpath, kwargs)`` pair. ``kwargs`` holds every
            constructor argument, so that ``__init__(**kwargs)`` rebuilds an
            equivalent trigger.
        """
        return (
            "airflow.providers.google.cloud.triggers.mlengine.MLEngineStartTrainingJobTrigger",
            {
                "conn_id": self.conn_id,
                "job_id": self.job_id,
                "poll_interval": self.poll_interval,
                "region": self.region,
                "project_id": self.project_id,
                "runtime_version": self.runtime_version,
                "python_version": self.python_version,
                "job_dir": self.job_dir,
                "package_uris": self.package_uris,
                "training_python_module": self.training_python_module,
                "training_args": self.training_args,
                "labels": self.labels,
                # These two were previously left out, so a trigger rebuilt
                # from its serialized form silently fell back to the
                # constructor defaults for both.
                "gcp_conn_id": self.gcp_conn_id,
                "impersonation_chain": self.impersonation_chain,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
        """
        Get current job execution status and yields a TriggerEvent.

        The hook's ``get_job_status`` result is handled as follows:

        * ``"success"``: yield a success event carrying ``job_id`` and stop.
        * ``"pending"``: sleep for ``poll_interval`` seconds and poll again.
        * anything else: yield an error event whose message is the hook's
          response, and stop.

        Any exception raised while polling is logged and reported as an
        error event instead of propagating.
        """
        # NOTE(review): `_get_async_hook` is not defined in the part of the
        # class visible in this review; it is expected to be provided
        # elsewhere on the class. Confirm before relying on this block alone.
        hook = self._get_async_hook()
        try:
            while True:
                # Poll for job execution status
                response_from_hook = await hook.get_job_status(
                    job_id=self.job_id, project_id=self.project_id
                )
                if response_from_hook == "success":
                    yield TriggerEvent(
                        {
                            "job_id": self.job_id,
                            "status": "success",
                            "message": "Job completed",
                        }
                    )
                    return
                elif response_from_hook == "pending":
                    self.log.info("Job is still running...")
                    self.log.info("Sleeping for %s seconds.", self.poll_interval)
                    await asyncio.sleep(self.poll_interval)
                else:
                    yield TriggerEvent({"status": "error", "message": response_from_hook})
                    return
        except Exception as e:
            self.log.exception("Exception occurred while checking for query completion")
            yield TriggerEvent({"status": "error", "message": str(e)})