# Source code for airflow.providers.microsoft.azure.fs.msgraph
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from airflow.providers.microsoft.azure.utils import get_field
from airflow.providers.microsoft.azure.version_compat import BaseHook
if TYPE_CHECKING:
from fsspec import AbstractFileSystem
[docs]
# Scheme names handled by this filesystem module.
# NOTE(review): presumably consumed by Airflow's object-storage layer to map
# URL schemes to ``get_fs`` - confirm against the caller.
schemes = [
    "msgraph",
    "sharepoint",
    "onedrive",
    "msgd",
]
[docs]
def get_fs(conn_id: str | None, storage_options: dict[str, Any] | None = None) -> AbstractFileSystem:
    """
    Create an ``MSGDriveFS`` filesystem configured from an Airflow connection.

    When ``conn_id`` is ``None``, ``MSGDriveFS({})`` is returned immediately and
    ``storage_options`` is not consulted.

    Otherwise the connection is loaded through ``BaseHook.get_connection`` and the
    options are assembled as follows:

    1. ``client_id`` / ``client_secret`` / ``tenant_id`` are taken from the connection's
       ``login`` / ``password`` / ``host``; when one of those is falsy the same-named
       field is looked up with ``get_field`` instead. Falsy results are not stored.
    2. ``drive_id`` and the optional OAuth2 fields are looked up with ``get_field``.
    3. ``storage_options`` is applied last and therefore overrides anything derived
       from the connection.

    OAuth2 client parameters are forwarded only when all three credentials end up
    truthy; otherwise an empty dict is passed to ``MSGDriveFS``.

    :param conn_id: Airflow connection id, or ``None`` for an unconfigured filesystem.
    :param storage_options: Optional overrides merged on top of the connection values.
    :return: The constructed ``MSGDriveFS`` instance.
    """
    from msgraphfs import MSGDriveFS

    if conn_id is None:
        return MSGDriveFS({})

    connection = BaseHook.get_connection(conn_id)
    extras = connection.extra_dejson
    conn_type = connection.conn_type or "msgraph"

    def lookup(name: str) -> Any:
        # Single place for the repetitive keyword arguments of ``get_field``.
        return get_field(conn_id=conn_id, conn_type=conn_type, extras=extras, field_name=name)

    # Connection attributes win; ``get_field`` is only called when the attribute is falsy.
    credentials = {
        "client_id": connection.login or lookup("client_id"),
        "client_secret": connection.password or lookup("client_secret"),
        "tenant_id": connection.host or lookup("tenant_id"),
    }
    options: dict[str, Any] = {key: secret for key, secret in credentials.items() if secret}

    # Optional OAuth2 settings that are forwarded to the client when present.
    oauth2_optional_keys = (
        "scope",
        "token_endpoint",
        "redirect_uri",
        "token_endpoint_auth_method",
        "code_challenge_method",
        "update_token",
        "username",
        "password",
    )

    for name in ("drive_id", *oauth2_optional_keys):
        found = lookup(name)
        if found is None:
            continue
        if found == "":
            # An explicit empty string clears the option instead of setting it.
            options.pop(name, "")
        else:
            options[name] = found

    # Caller-supplied storage options take precedence over connection values.
    if storage_options:
        options.update(storage_options)

    credential_keys = ("client_id", "client_secret", "tenant_id")
    oauth2_client_params: dict[str, Any] = {}
    if all(options.get(key) for key in credential_keys):
        oauth2_client_params = {key: options[key] for key in credential_keys}
        # Extra OAuth2 settings are only meaningful alongside complete credentials.
        oauth2_client_params.update(
            (key, options[key]) for key in oauth2_optional_keys if key in options
        )

    return MSGDriveFS(
        drive_id=options.get("drive_id"),
        oauth2_client_params=oauth2_client_params,
    )