Airflow Summit 2025 is coming October 07-09. Register now for early-bird tickets!

Source code for airflow.providers.git.bundles.git

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
from contextlib import nullcontext
from pathlib import Path
from urllib.parse import urlparse

import structlog
from git import Repo
from git.exc import BadName, GitCommandError, NoSuchPathError

from airflow.dag_processing.bundles.base import (
    BaseDagBundle,
)
from airflow.exceptions import AirflowException
from airflow.providers.git.hooks.git import GitHook

# Module-level structured logger; each bundle binds its own context onto it in __init__.
log = structlog.get_logger(__name__)


class GitDagBundle(BaseDagBundle):
    """
    git DAG bundle - exposes a git repository as a DAG bundle.

    Instead of cloning the repository every time, we clone the repository once into a bare repo from the
    source and then do a clone for each version from there.

    :param tracking_ref: Branch or tag for this DAG bundle
    :param subdir: Subdirectory within the repository where the DAGs are stored (Optional)
    :param git_conn_id: Connection ID for SSH/token based connection to the repository (Optional)
    :param repo_url: Explicit Git repository URL to override the connection's host. (Optional)
    """

    supports_versioning = True

    def __init__(
        self,
        *,
        tracking_ref: str,
        subdir: str | None = None,
        git_conn_id: str | None = None,
        repo_url: str | None = None,
        **kwargs,
    ) -> None:
        # NOTE(review): name, version, base_dir and versions_dir are read below right after this call,
        # so they are presumably provided by BaseDagBundle - confirm against the base class.
        super().__init__(**kwargs)
        self.tracking_ref = tracking_ref
        self.subdir = subdir
        self.bare_repo_path = self.base_dir / "bare"
        # A pinned version gets its own working clone; otherwise one shared clone follows tracking_ref.
        if self.version:
            self.repo_path = self.versions_dir / self.version
        else:
            self.repo_path = self.base_dir / "tracking_repo"
        self.git_conn_id = git_conn_id
        self.repo_url = repo_url

        self._log = log.bind(
            bundle_name=self.name,
            version=self.version,
            bare_repo_path=self.bare_repo_path,
            repo_path=self.repo_path,
            versions_path=self.versions_dir,
            git_conn_id=self.git_conn_id,
        )
        self._log.debug("bundle configured")

        self.hook: GitHook | None = None
        # A hook is only created when no explicit repo_url was given; the URL then comes from the hook.
        if not repo_url:
            if not git_conn_id:
                self._log.debug("Neither git_conn_id nor repo_url provided; loading 'git_default'")
                git_conn_id = "git_default"

            try:
                self.hook = GitHook(git_conn_id=git_conn_id)
            except AirflowException as e:
                # Best effort: a missing repo_url is reported later, by initialize().
                self._log.warning("Could not create GitHook", conn_id=git_conn_id, exc=e)
            else:
                self.repo_url = self.hook.repo_url
                self._log.debug("repo_url updated from hook")

    def _initialize(self):
        """Create the bare repo and the working clone, then pin the version or refresh the tracking ref."""
        with self.lock():
            cm = self.hook.configure_hook_env() if self.hook else nullcontext()
            # Only the steps that talk to the remote run inside the hook environment.
            with cm:
                self._clone_bare_repo_if_required()
                self._ensure_version_in_bare_repo()

            self._clone_repo_if_required()
            self.repo.git.checkout(self.tracking_ref)
            self._log.debug("bundle initialize", version=self.version)
            if self.version:
                if not self._has_version(self.repo, self.version):
                    self.repo.remotes.origin.fetch()
                # Point HEAD at the requested commit and force index + working tree to match it.
                self.repo.head.set_reference(str(self.repo.commit(self.version)))
                self.repo.head.reset(index=True, working_tree=True)
            else:
                self.refresh()

    def initialize(self) -> None:
        """
        Prepare the on-disk repositories for this bundle.

        :raises AirflowException: if no repository URL is known.
        """
        if not self.repo_url:
            raise AirflowException(f"Connection {self.git_conn_id} doesn't have a host url")
        self._initialize()
        super().initialize()

    def _clone_repo_if_required(self) -> None:
        """
        Clone the working repo from the local bare repo unless it already exists, then open it.

        :raises AirflowException: if the bare repo path is missing.
        """
        if not os.path.exists(self.repo_path):
            self._log.info("Cloning repository", repo_path=self.repo_path, bare_repo_path=self.bare_repo_path)
            try:
                Repo.clone_from(
                    url=self.bare_repo_path,
                    to_path=self.repo_path,
                )
            except NoSuchPathError as e:
                # Protection should the bare repo be removed manually
                raise AirflowException(f"Repository path: {self.bare_repo_path} not found") from e
        else:
            self._log.debug("repo exists", repo_path=self.repo_path)
        self.repo = Repo(self.repo_path)

    def _clone_bare_repo_if_required(self) -> None:
        """
        Clone the remote into the bare repo path unless it already exists, then open it.

        :raises AirflowException: if no repository URL is known or the clone command fails.
        """
        if not self.repo_url:
            raise AirflowException(f"Connection {self.git_conn_id} doesn't have a host url")
        if not os.path.exists(self.bare_repo_path):
            self._log.info("Cloning bare repository", bare_repo_path=self.bare_repo_path)
            try:
                Repo.clone_from(
                    url=self.repo_url,
                    to_path=self.bare_repo_path,
                    bare=True,
                    env=self.hook.env if self.hook else None,
                )
            except GitCommandError as e:
                raise AirflowException("Error cloning repository") from e
        self.bare_repo = Repo(self.bare_repo_path)

    def _ensure_version_in_bare_repo(self) -> None:
        """
        Make sure the pinned version exists in the bare repo, fetching once if it does not.

        :raises AirflowException: if the version is still missing after the fetch.
        """
        if not self.version:
            return
        if not self._has_version(self.bare_repo, self.version):
            self._fetch_bare_repo()
            if not self._has_version(self.bare_repo, self.version):
                raise AirflowException(f"Version {self.version} not found in the repository")

    def __repr__(self):
        return (
            f"<GitDagBundle("
            f"name={self.name!r}, "
            f"tracking_ref={self.tracking_ref!r}, "
            f"subdir={self.subdir!r}, "
            f"version={self.version!r}"
            f")>"
        )

    def get_current_version(self) -> str:
        """Return the hex SHA of the commit the working repo's HEAD points at."""
        return self.repo.head.commit.hexsha

    @property
    def path(self) -> Path:
        """Return the working repo path, extended by ``subdir`` when one is configured."""
        if self.subdir:
            return self.repo_path / self.subdir
        return self.repo_path

    @staticmethod
    def _has_version(repo: Repo, version: str) -> bool:
        """Return True when ``repo.commit(version)`` resolves, False on ``BadName``/``ValueError``."""
        try:
            repo.commit(version)
        except (BadName, ValueError):
            return False
        return True

    def _fetch_bare_repo(self):
        """Fetch all branches and tags from origin into the bare repo, force-updating local refs."""
        refspecs = ["+refs/heads/*:refs/heads/*", "+refs/tags/*:refs/tags/*"]
        cm = nullcontext()
        # Pass the hook's SSH command through to git for this fetch when the hook defines one.
        if self.hook and (cmd := self.hook.env.get("GIT_SSH_COMMAND")):
            cm = self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=cmd)
        with cm:
            self.bare_repo.remotes.origin.fetch(refspecs)

    def refresh(self) -> None:
        """
        Fetch the latest refs and hard-reset the working repo to the tracking ref.

        :raises AirflowException: if the bundle is pinned to a specific version.
        """
        if self.version:
            raise AirflowException("Refreshing a specific version is not supported")

        with self.lock():
            cm = self.hook.configure_hook_env() if self.hook else nullcontext()
            with cm:
                self._fetch_bare_repo()
                self.repo.remotes.origin.fetch(
                    ["+refs/heads/*:refs/remotes/origin/*", "+refs/tags/*:refs/tags/*"]
                )
                # Prefer the remote-tracking branch; fall back to the bare name (e.g. for a tag).
                remote_branch = f"origin/{self.tracking_ref}"
                if any(ref.name == remote_branch for ref in self.repo.remotes.origin.refs):
                    target = remote_branch
                else:
                    target = self.tracking_ref
                self.repo.head.reset(target, index=True, working_tree=True)

    @staticmethod
    def _convert_git_ssh_url_to_https(url: str) -> str:
        """
        Convert an scp-style ``git@host:path[.git]`` URL into ``https://host/path``.

        Only a trailing ``.git`` is removed, so names such as ``my.github.io`` stay intact.

        :raises ValueError: if the URL does not start with ``git@`` or has no ``:`` separator.
        """
        prefix = "git@"
        if not url.startswith(prefix):
            raise ValueError(f"Invalid git SSH URL: {url}")
        user_host, separator, repo_path = url.partition(":")
        if not separator:
            raise ValueError(f"Invalid git SSH URL: {url}")
        domain = "https://" + user_host[len(prefix) :]
        if repo_path.endswith(".git"):
            repo_path = repo_path[:-4]
        return f"{domain}/{repo_path}"

    def view_url(self, version: str | None = None) -> str | None:
        """
        Build a web URL for browsing ``version`` of this bundle on a known git host.

        :param version: Version to link to.
        :return: The URL, or None when no version or repo URL is available, the URL cannot be parsed,
            or the host is not github.com, gitlab.com, bitbucket.org or a subdomain of one of them.
        """
        if not version:
            return None
        url = self.repo_url
        if not url:
            return None
        if url.startswith("git@"):
            try:
                url = self._convert_git_ssh_url_to_https(url)
            except ValueError:
                # Malformed SSH URL: no link can be built.
                return None
        if url.endswith(".git"):
            url = url[:-4]
        parsed_url = urlparse(url)
        host = parsed_url.hostname
        if not host:
            return None
        # Drop embedded credentials so they never appear in the returned link.
        if parsed_url.username or parsed_url.password:
            new_netloc = host
            if parsed_url.port:
                new_netloc += f":{parsed_url.port}"
            url = parsed_url._replace(netloc=new_netloc).geturl()
        host_patterns = {
            "github.com": f"{url}/tree/{version}",
            "gitlab.com": f"{url}/-/tree/{version}",
            "bitbucket.org": f"{url}/src/{version}",
        }
        if self.subdir:
            host_patterns = {k: f"{v}/{self.subdir}" for k, v in host_patterns.items()}
        for allowed_host, template in host_patterns.items():
            if host == allowed_host or host.endswith(f".{allowed_host}"):
                return template
        return None

Was this entry helpful?