Source code for airflow.providers.git.bundles.git

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
from contextlib import nullcontext
from pathlib import Path
from urllib.parse import urlparse

import structlog
from git import Repo
from git.exc import BadName, GitCommandError, NoSuchPathError

from airflow.dag_processing.bundles.base import (
    BaseDagBundle,
)
from airflow.exceptions import AirflowException
from airflow.providers.git.hooks.git import GitHook

# Module-level structlog logger; GitDagBundle binds per-bundle context onto it in __init__.
log = structlog.get_logger(__name__)
class GitDagBundle(BaseDagBundle):
    """
    git DAG bundle - exposes a git repository as a DAG bundle.

    Instead of cloning the repository every time, we clone the repository once into a bare repo from the source
    and then do a clone for each version from there.

    :param tracking_ref: Branch or tag for this DAG bundle
    :param subdir: Subdirectory within the repository where the DAGs are stored (Optional)
    :param git_conn_id: Connection ID for SSH/token based connection to the repository (Optional)
    :param repo_url: Explicit Git repository URL to override the connection's host. (Optional)
    """

    supports_versioning = True

    def __init__(
        self,
        *,
        tracking_ref: str,
        subdir: str | None = None,
        git_conn_id: str = "git_default",
        repo_url: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.tracking_ref = tracking_ref
        self.subdir = subdir
        self.bare_repo_path = self.base_dir / "bare"
        # A pinned version gets its own checkout; otherwise a single tracking checkout is reused.
        if self.version:
            self.repo_path = self.versions_dir / self.version
        else:
            self.repo_path = self.base_dir / "tracking_repo"
        self.git_conn_id = git_conn_id
        self.repo_url = repo_url

        self._log = log.bind(
            bundle_name=self.name,
            version=self.version,
            bare_repo_path=self.bare_repo_path,
            repo_path=self.repo_path,
            versions_path=self.versions_dir,
            git_conn_id=self.git_conn_id,
            repo_url=self.repo_url,
        )
        self._log.debug("bundle configured")

        # Always define the attribute: if GitHook construction fails, later methods must still be
        # able to run (without hook-provided environment) instead of hitting AttributeError.
        self.hook: GitHook | None = None
        try:
            self.hook = GitHook(git_conn_id=self.git_conn_id, repo_url=self.repo_url)
            self.repo_url = self.hook.repo_url
            self._log.debug("repo_url updated from hook", repo_url=self.repo_url)
        except AirflowException as e:
            self._log.warning("Could not create GitHook", conn_id=self.git_conn_id, exc=e)

    def _hook_env(self):
        """Return the hook's environment context manager, or a no-op one when no hook was created."""
        if self.hook is None:
            return nullcontext()
        return self.hook.configure_hook_env()

    def _initialize(self):
        """
        Prepare the bare repo and the working checkout, then move the checkout to the target state.

        With a pinned ``version`` the working tree is hard-reset to that commit; otherwise the bundle
        is refreshed to the tip of ``tracking_ref``.
        """
        with self.lock():
            # Only the bare repo talks to the remote, so only it needs the hook's environment.
            with self._hook_env():
                self._clone_bare_repo_if_required()
                self._ensure_version_in_bare_repo()

            # The working clone's origin is the local bare repo.
            self._clone_repo_if_required()
            self.repo.git.checkout(self.tracking_ref)
            self._log.debug("bundle initialize", version=self.version)
            if self.version:
                if not self._has_version(self.repo, self.version):
                    self.repo.remotes.origin.fetch()
                self.repo.head.set_reference(str(self.repo.commit(self.version)))
                self.repo.head.reset(index=True, working_tree=True)
            else:
                # refresh() sets up its own hook environment.
                self.refresh()

    def initialize(self) -> None:
        """
        Validate that a repository URL is known, then set up the repositories.

        :raises AirflowException: if no repository URL could be determined.
        """
        if not self.repo_url:
            raise AirflowException(f"Connection {self.git_conn_id} doesn't have a host url")
        self._initialize()
        super().initialize()

    def _clone_repo_if_required(self) -> None:
        """
        Clone the working repository from the local bare repo if it does not exist yet.

        Sets ``self.repo``.

        :raises AirflowException: if the bare repository path is missing.
        """
        if not os.path.exists(self.repo_path):
            self._log.info("Cloning repository", repo_path=self.repo_path, bare_repo_path=self.bare_repo_path)
            try:
                Repo.clone_from(
                    url=self.bare_repo_path,
                    to_path=self.repo_path,
                )
            except NoSuchPathError as e:
                # Protection should the bare repo be removed manually
                raise AirflowException(f"Repository path: {self.bare_repo_path} not found") from e
        else:
            self._log.debug("repo exists", repo_path=self.repo_path)
        self.repo = Repo(self.repo_path)

    def _clone_bare_repo_if_required(self) -> None:
        """
        Clone ``repo_url`` as a bare repository if it is not present locally.

        Sets ``self.bare_repo``.

        :raises AirflowException: if there is no repository URL or the clone fails.
        """
        if not self.repo_url:
            raise AirflowException(f"Connection {self.git_conn_id} doesn't have a host url")
        if not os.path.exists(self.bare_repo_path):
            self._log.info("Cloning bare repository", bare_repo_path=self.bare_repo_path)
            try:
                Repo.clone_from(
                    url=self.repo_url,
                    to_path=self.bare_repo_path,
                    bare=True,
                    env=self.hook.env if self.hook else None,
                )
            except GitCommandError as e:
                raise AirflowException("Error cloning repository") from e
        self.bare_repo = Repo(self.bare_repo_path)

    def _ensure_version_in_bare_repo(self) -> None:
        """
        Make sure the pinned ``version`` exists in the bare repo, fetching once if needed.

        :raises AirflowException: if the version is still missing after a fetch.
        """
        if not self.version:
            return
        if not self._has_version(self.bare_repo, self.version):
            self._fetch_bare_repo()
            if not self._has_version(self.bare_repo, self.version):
                raise AirflowException(f"Version {self.version} not found in the repository")

    def __repr__(self):
        return (
            f"<GitDagBundle("
            f"name={self.name!r}, "
            f"tracking_ref={self.tracking_ref!r}, "
            f"subdir={self.subdir!r}, "
            f"version={self.version!r}"
            f")>"
        )

    def get_current_version(self) -> str:
        """Return the commit SHA currently checked out in the working repository."""
        return self.repo.head.commit.hexsha

    @property
    def path(self) -> Path:
        """Directory containing the DAG files: ``repo_path``, or ``repo_path / subdir`` when set."""
        if self.subdir:
            return self.repo_path / self.subdir
        return self.repo_path

    @staticmethod
    def _has_version(repo: Repo, version: str) -> bool:
        """Return True if ``version`` resolves to a commit in ``repo``."""
        try:
            repo.commit(version)
            return True
        except (BadName, ValueError):
            return False

    def _fetch_bare_repo(self):
        """Fetch all branches and tags from the remote into the bare repo (forced refspecs)."""
        refspecs = ["+refs/heads/*:refs/heads/*", "+refs/tags/*:refs/tags/*"]
        env = self.hook.env if self.hook else None
        if env:
            with self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=env.get("GIT_SSH_COMMAND")):
                self.bare_repo.remotes.origin.fetch(refspecs)
        else:
            self.bare_repo.remotes.origin.fetch(refspecs)

    def refresh(self) -> None:
        """
        Update the bare repo from the remote and hard-reset the working tree to ``tracking_ref``.

        :raises AirflowException: if this bundle is pinned to a specific version.
        """
        if self.version:
            raise AirflowException("Refreshing a specific version is not supported")

        with self.lock():
            with self._hook_env():
                self._fetch_bare_repo()
                self.repo.remotes.origin.fetch(
                    ["+refs/heads/*:refs/remotes/origin/*", "+refs/tags/*:refs/tags/*"]
                )
                # Branches resolve via their remote-tracking ref; tags (and anything else) resolve directly.
                remote_branch = f"origin/{self.tracking_ref}"
                if remote_branch in [ref.name for ref in self.repo.remotes.origin.refs]:
                    target = remote_branch
                else:
                    target = self.tracking_ref
                self.repo.head.reset(target, index=True, working_tree=True)

    @staticmethod
    def _convert_git_ssh_url_to_https(url: str) -> str:
        """
        Convert an scp-style SSH URL (``git@host:path[.git]``) into ``https://host/path``.

        :raises ValueError: if ``url`` is not of the form ``git@host:path``.
        """
        if not url.startswith("git@"):
            raise ValueError(f"Invalid git SSH URL: {url}")
        host, sep, repo_path = url[len("git@") :].partition(":")
        if not sep:
            raise ValueError(f"Invalid git SSH URL: {url}")
        # Strip only a trailing ".git"; ".git" may legitimately appear inside a repo name.
        if repo_path.endswith(".git"):
            repo_path = repo_path[:-4]
        return f"https://{host}/{repo_path}"

    def view_url(self, version: str | None = None) -> str | None:
        """
        Build a browser URL for ``version`` on GitHub, GitLab or Bitbucket (including subdomains).

        Credentials embedded in the URL are removed. Returns None when there is no version or URL,
        or when the host is not one of the supported providers.
        """
        if not version:
            return None
        url = self.repo_url
        if not url:
            return None
        if url.startswith("git@"):
            url = self._convert_git_ssh_url_to_https(url)
        if url.endswith(".git"):
            url = url[:-4]
        parsed_url = urlparse(url)
        host = parsed_url.hostname
        if not host:
            return None
        # Never leak credentials (e.g. tokens in https://user:token@host/...) into a UI link.
        if parsed_url.username or parsed_url.password:
            new_netloc = host
            if parsed_url.port:
                new_netloc += f":{parsed_url.port}"
            url = parsed_url._replace(netloc=new_netloc).geturl()
        host_patterns = {
            "github.com": f"{url}/tree/{version}",
            "gitlab.com": f"{url}/-/tree/{version}",
            "bitbucket.org": f"{url}/src/{version}",
        }
        for allowed_host, template in host_patterns.items():
            if host == allowed_host or host.endswith(f".{allowed_host}"):
                return template
        return None

Was this entry helpful?