Source code for airflow.providers.git.hooks.git
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import contextlib
import json
import logging
import os
import shlex
import tempfile
from typing import Any

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
[docs]
# Module-level logger, registered under this module's dotted name.
log = logging.getLogger(name=__name__)
[docs]
class GitHook(BaseHook):
"""
Hook for git repositories.
:param git_conn_id: Connection ID for SSH connection to the repository
"""
[docs]
conn_name_attr = "git_conn_id"
[docs]
default_conn_name = "git_default"
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
    """
    Describe how the git connection form fields are presented.

    The ``schema`` field is hidden, ``login``/``host``/``password`` get
    git-specific labels, and the ``extra`` field shows a JSON placeholder
    listing the SSH key extras.

    :return: Mapping with ``hidden_fields``, ``relabeling`` and
        ``placeholders`` entries.
    """
    extra_placeholder = json.dumps(
        {
            "key_file": "optional/path/to/keyfile",
            "private_key": "optional inline private key",
        }
    )
    labels = {
        "login": "Username",
        "host": "Repository URL",
        "password": "Access Token (optional)",
    }
    behaviour: dict[str, Any] = {}
    behaviour["hidden_fields"] = ["schema"]
    behaviour["relabeling"] = labels
    behaviour["placeholders"] = {"extra": extra_placeholder}
    return behaviour
def __init__(
    self, git_conn_id: str = "git_default", repo_url: str | None = None, *args, **kwargs
) -> None:
    """
    Load the git connection and prepare the repository URL and auth settings.

    :param git_conn_id: ID of the Airflow connection to read settings from.
    :param repo_url: Optional repository URL; falls back to the connection's
        ``host`` when empty.
    :param args: Accepted for compatibility; not used here.
    :param kwargs: Accepted for compatibility; not used here.
    :raises AirflowException: if both the ``key_file`` and ``private_key``
        extras are set.
    """
    super().__init__()
    conn = self.get_connection(git_conn_id)
    # Read the parsed extras once instead of on every lookup.
    extras = conn.extra_dejson

    self.repo_url = repo_url if repo_url else conn.host
    self.auth_token = conn.password
    self.private_key = extras.get("private_key")
    self.key_file = extras.get("key_file")
    # Passed verbatim to ssh's StrictHostKeyChecking option; "no" by default.
    self.strict_host_key_checking = extras.get("strict_host_key_checking", "no")
    # Extra environment variables for git commands (e.g. GIT_SSH_COMMAND).
    self.env: dict[str, str] = {}

    # An SSH key may come from a file or inline, but not both.
    if self.private_key and self.key_file:
        raise AirflowException("Both 'key_file' and 'private_key' cannot be provided at the same time")

    self._process_git_auth_url()
def _build_ssh_command(self, key_path: str) -> str:
return (
f"ssh -i {key_path} "
f"-o IdentitiesOnly=yes "
f"-o StrictHostKeyChecking={self.strict_host_key_checking}"
)
def _process_git_auth_url(self):
if not isinstance(self.repo_url, str):
return
if self.auth_token and self.repo_url.startswith("https://"):
self.repo_url = self.repo_url.replace("https://", f"https://{self.auth_token}@")
elif not self.repo_url.startswith("git@") or not self.repo_url.startswith("https://"):
self.repo_url = os.path.expanduser(self.repo_url)
[docs]
def set_git_env(self, key: str) -> None:
    """
    Store the ssh command for the given key as ``GIT_SSH_COMMAND`` in ``self.env``.

    :param key: Path of the private key file, handed to ``_build_ssh_command``.
    """
    ssh_command = self._build_ssh_command(key)
    self.env["GIT_SSH_COMMAND"] = ssh_command
@contextlib.contextmanager