
Source code for tests.system.amazon.aws.example_bedrock_retrieve_and_generate

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
import logging
import os.path
import tempfile
from datetime import datetime
from time import sleep
from typing import TYPE_CHECKING
from urllib.request import urlretrieve

import boto3
from botocore.exceptions import ClientError
from opensearchpy import (
    AuthorizationException,
    AWSV4SignerAuth,
    OpenSearch,
    RequestsHttpConnection,
)

from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook
from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.hooks.sts import StsHook
from airflow.providers.amazon.aws.operators.bedrock import (
    BedrockCreateDataSourceOperator,
    BedrockCreateKnowledgeBaseOperator,
    BedrockIngestDataOperator,
    BedrockInvokeModelOperator,
    BedrockRaGOperator,
    BedrockRetrieveOperator,
)
from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator
from airflow.providers.amazon.aws.sensors.bedrock import (
    BedrockIngestionJobSensor,
    BedrockKnowledgeBaseActiveSensor,
)
from airflow.providers.amazon.aws.sensors.opensearch_serverless import (
    OpenSearchServerlessCollectionActiveSensor,
)
from airflow.providers.amazon.aws.utils import get_botocore_version
from airflow.providers.standard.operators.empty import EmptyOperator

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if TYPE_CHECKING:
    from airflow.decorators import task, task_group
    from airflow.models.baseoperator import chain
    from airflow.models.dag import DAG
    from airflow.sdk import Label
else:
    if AIRFLOW_V_3_0_PLUS:
        from airflow.sdk import DAG, Label, chain, task, task_group
    else:
        # Airflow 2.10 compat
        from airflow.decorators import task, task_group
        from airflow.models.baseoperator import chain
        from airflow.models.dag import DAG
        from airflow.utils.edgemodifier import Label
from airflow.utils.trigger_rule import TriggerRule

from system.amazon.aws.utils import SystemTestContextBuilder

#######################################################################
# NOTE:
#   Access to the following foundation models must be requested via
#   the Amazon Bedrock console and may take up to 24 hours to apply:
#######################################################################

CLAUDE_MODEL_ID = "anthropic.claude-v2"
TITAN_MODEL_ID = "amazon.titan-embed-text-v1"

# Externally fetched variables:
ROLE_ARN_KEY = "ROLE_ARN"

sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()

DAG_ID = "example_bedrock_knowledge_base"

log = logging.getLogger(__name__)
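
# At a high level, this example DAG creates an S3 bucket and an Amazon OpenSearch
# Serverless vector collection, builds a Bedrock knowledge base backed by that
# collection, ingests the Amazon shareholder letters uploaded to S3, queries the
# knowledge base with the Bedrock RaG and Retrieve operators, and then tears all of
# those resources down again.
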
@task_group
def external_sources_rag_group():
    """External Sources were added in boto 1.34.90, skip this operator if the version is below that."""

    # [START howto_operator_bedrock_external_sources_rag]
    external_sources_rag = BedrockRaGOperator(
        task_id="external_sources_rag",
        input="Who was the CEO of Amazon in 2022?",
        source_type="EXTERNAL_SOURCES",
        model_arn=f"arn:aws:bedrock:{region_name}::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
        sources=[
            {
                "sourceType": "S3",
                "s3Location": {"uri": f"s3://{bucket_name}/AMZN-2022-Shareholder-Letter.pdf"},
            }
        ],
    )
    # [END howto_operator_bedrock_external_sources_rag]

    @task.branch
    def run_or_skip():
        log.info("Found botocore version %s.", botocore_version := get_botocore_version())
        return end_workflow.task_id if botocore_version < (1, 34, 90) else external_sources_rag.task_id

    run_or_skip = run_or_skip()
    end_workflow = EmptyOperator(task_id="end_workflow", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)

    chain(run_or_skip, Label("Boto version does not support External Sources"), end_workflow)
    chain(run_or_skip, external_sources_rag, end_workflow)

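
# The Label() on the skip edge above is rendered in the Airflow Graph view, which makes it
# clear why external_sources_rag was skipped when the installed botocore is too old.
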
@task
def create_opensearch_policies(bedrock_role_arn: str, collection_name: str, policy_name_suffix: str) -> None:
    """
    Create security, network and data access policies within Amazon OpenSearch Serverless.

    :param bedrock_role_arn: Arn of the Bedrock Knowledge Base Execution Role.
    :param collection_name: Name of the OpenSearch collection to apply the policies to.
    :param policy_name_suffix: EnvironmentID or other unique suffix to append to the policy name.
    """
    encryption_policy_name = f"{naming_prefix}sp-{policy_name_suffix}"
    network_policy_name = f"{naming_prefix}np-{policy_name_suffix}"
    access_policy_name = f"{naming_prefix}ap-{policy_name_suffix}"

    def _create_security_policy(name, policy_type, policy):
        try:
            aoss_client.conn.create_security_policy(name=name, policy=json.dumps(policy), type=policy_type)
        except ClientError as e:
            if e.response["Error"]["Code"] == "ConflictException":
                log.info("OpenSearch security policy %s already exists.", name)
            raise

    def _create_access_policy(name, policy_type, policy):
        try:
            aoss_client.conn.create_access_policy(name=name, policy=json.dumps(policy), type=policy_type)
        except ClientError as e:
            if e.response["Error"]["Code"] == "ConflictException":
                log.info("OpenSearch data access policy %s already exists.", name)
            raise

    _create_security_policy(
        name=encryption_policy_name,
        policy_type="encryption",
        policy={
            "Rules": [{"Resource": [f"collection/{collection_name}"], "ResourceType": "collection"}],
            "AWSOwnedKey": True,
        },
    )

    _create_security_policy(
        name=network_policy_name,
        policy_type="network",
        policy=[
            {
                "Rules": [{"Resource": [f"collection/{collection_name}"], "ResourceType": "collection"}],
                "AllowFromPublic": True,
            }
        ],
    )

    _create_access_policy(
        name=access_policy_name,
        policy_type="data",
        policy=[
            {
                "Rules": [
                    {
                        "Resource": [f"collection/{collection_name}"],
                        "Permission": [
                            "aoss:CreateCollectionItems",
                            "aoss:DeleteCollectionItems",
                            "aoss:UpdateCollectionItems",
                            "aoss:DescribeCollectionItems",
                        ],
                        "ResourceType": "collection",
                    },
                    {
                        "Resource": [f"index/{collection_name}/*"],
                        "Permission": [
                            "aoss:CreateIndex",
                            "aoss:DeleteIndex",
                            "aoss:UpdateIndex",
                            "aoss:DescribeIndex",
                            "aoss:ReadDocument",
                            "aoss:WriteDocument",
                        ],
                        "ResourceType": "index",
                    },
                ],
                "Principal": [(StsHook().conn.get_caller_identity()["Arn"]), bedrock_role_arn],
            }
        ],
    )

@task
def create_collection(collection_name: str):
    """
    Call the Amazon OpenSearch Serverless API and create a collection with the provided name.

    :param collection_name: The name of the Collection to create.
    """
    log.info("\nCreating collection: %s.", collection_name)
    return aoss_client.conn.create_collection(name=collection_name, type="VECTORSEARCH")[
        "createCollectionDetail"
    ]["id"]

@task
def create_vector_index(index_name: str, collection_id: str, region: str):
    """
    Use the OpenSearchPy client to create the vector index for the Amazon OpenSearch Serverless Collection.

    :param index_name: The vector index name to create.
    :param collection_id: ID of the collection to be indexed.
    :param region: Name of the AWS region the collection resides in.
    """
    # Build the OpenSearch client
    oss_client = OpenSearch(
        hosts=[{"host": f"{collection_id}.{region}.aoss.amazonaws.com", "port": 443}],
        http_auth=AWSV4SignerAuth(boto3.Session().get_credentials(), region, "aoss"),
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection,
        timeout=300,
    )
    index_config = {
        "settings": {
            "index.knn": "true",
            "number_of_shards": 1,
            "knn.algo_param.ef_search": 512,
            "number_of_replicas": 0,
        },
        "mappings": {
            "properties": {
                "vector": {
                    "type": "knn_vector",
                    "dimension": 1536,
                    "method": {"name": "hnsw", "engine": "faiss", "space_type": "l2"},
                },
                "text": {"type": "text"},
                "text-metadata": {"type": "text"},
            }
        },
    }

    retries = 35
    while retries > 0:
        try:
            response = oss_client.indices.create(index=index_name, body=json.dumps(index_config))
            log.info("Creating index: %s.", response)
            break
        except AuthorizationException as e:
            # Index creation can take up to a minute and there is no (apparent?) way to check the current state.
            log.info(
                "Access denied; policy permissions have likely not yet propagated, %s tries remaining.",
                retries,
            )
            log.debug(e)
            retries -= 1
            if retries:
                sleep(2)
            else:
                raise

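
# Note: the "dimension": 1536 mapping above matches the vector size produced by the
# amazon.titan-embed-text-v1 embedding model that the knowledge base below is built with.
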
@task
def copy_data_to_s3(bucket: str):
    """
    Download some sample data and upload it to S3.

    :param bucket: Name of the Amazon S3 bucket to send the data to.
    """
    # Monkey patch the list of names available for NamedTemporaryFile so we can pick the names of the downloaded files.
    backup_get_candidate_names = tempfile._get_candidate_names  # type: ignore[attr-defined]
    destinations = iter(
        [
            "AMZN-2022-Shareholder-Letter.pdf",
            "AMZN-2021-Shareholder-Letter.pdf",
            "AMZN-2020-Shareholder-Letter.pdf",
            "AMZN-2019-Shareholder-Letter.pdf",
        ]
    )
    tempfile._get_candidate_names = lambda: destinations  # type: ignore[attr-defined]

    # Download the sample data files, save them as named temp files using the names above, and upload to S3.
    sources = [
        "https://s2.q4cdn.com/299287126/files/doc_financials/2023/ar/2022-Shareholder-Letter.pdf",
        "https://s2.q4cdn.com/299287126/files/doc_financials/2022/ar/2021-Shareholder-Letter.pdf",
        "https://s2.q4cdn.com/299287126/files/doc_financials/2021/ar/Amazon-2020-Shareholder-Letter-and-1997-Shareholder-Letter.pdf",
        "https://s2.q4cdn.com/299287126/files/doc_financials/2020/ar/2019-Shareholder-Letter.pdf",
    ]

    for source in sources:
        with tempfile.NamedTemporaryFile(mode="w", prefix="") as data_file:
            urlretrieve(source, data_file.name)
            S3Hook().conn.upload_file(
                Filename=data_file.name, Bucket=bucket, Key=os.path.basename(data_file.name)
            )

    # Revert the monkey patch.
    tempfile._get_candidate_names = backup_get_candidate_names  # type: ignore[attr-defined]
    # Verify the path reversion worked.
    with tempfile.NamedTemporaryFile(mode="w", prefix=""):
        # If the reversion above did not apply correctly, this will fail with
        # a StopIteration error because the iterator will run out of names.
        ...

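
# Note: tempfile._get_candidate_names is a private CPython helper (hence the
# "type: ignore[attr-defined]" markers above); the monkey patch is reverted as soon as
# the uploads complete so other users of tempfile are unaffected.
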
@task
def get_collection_arn(collection_id: str):
    """
    Return a collection ARN for a given collection ID.

    :param collection_id: ID of the collection to be indexed.
    """
    return next(
        colxn["arn"]
        for colxn in aoss_client.conn.list_collections()["collectionSummaries"]
        if colxn["id"] == collection_id
    )

# [START howto_operator_bedrock_delete_data_source]
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_data_source(knowledge_base_id: str, data_source_id: str):
    """
    Delete the Amazon Bedrock data source created earlier.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BedrockDeleteDataSource`

    :param knowledge_base_id: The unique identifier of the knowledge base which the data source is attached to.
    :param data_source_id: The unique identifier of the data source to delete.
    """
    log.info("Deleting data source %s from Knowledge Base %s.", data_source_id, knowledge_base_id)
    bedrock_agent_client.conn.delete_data_source(
        dataSourceId=data_source_id, knowledgeBaseId=knowledge_base_id
    )
# [END howto_operator_bedrock_delete_data_source]


# [START howto_operator_bedrock_delete_knowledge_base]
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_knowledge_base(knowledge_base_id: str):
    """
    Delete the Amazon Bedrock knowledge base created earlier.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BedrockDeleteKnowledgeBase`

    :param knowledge_base_id: The unique identifier of the knowledge base to delete.
    """
    log.info("Deleting Knowledge Base %s.", knowledge_base_id)
    bedrock_agent_client.conn.delete_knowledge_base(knowledgeBaseId=knowledge_base_id)
# [END howto_operator_bedrock_delete_knowledge_base]


@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_vector_index(index_name: str, collection_id: str):
    """
    Delete the vector index created earlier.

    :param index_name: The name of the vector index to delete.
    :param collection_id: ID of the collection to be indexed.
    """
    host = f"{collection_id}.{region_name}.aoss.amazonaws.com"
    credentials = boto3.Session().get_credentials()
    awsauth = AWSV4SignerAuth(credentials, region_name, "aoss")

    # Build the OpenSearch client
    oss_client = OpenSearch(
        hosts=[{"host": host, "port": 443}],
        http_auth=awsauth,
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection,
        timeout=300,
    )
    oss_client.indices.delete(index=index_name)

@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_collection(collection_id: str):
    """
    Delete the OpenSearch collection created earlier.

    :param collection_id: ID of the collection to delete.
    """
    log.info("Deleting collection %s.", collection_id)
    aoss_client.conn.delete_collection(id=collection_id)

@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_opensearch_policies(collection_name: str):
    """
    Delete the security, network and data access policies created earlier.

    :param collection_name: All policies in the given collection name will be deleted.
    """
    access_policies = aoss_client.conn.list_access_policies(
        type="data", resource=[f"collection/{collection_name}"]
    )["accessPolicySummaries"]
    log.info("Found access policies for %s: %s", collection_name, access_policies)
    if not access_policies:
        raise Exception("No access policies found?")

    for policy in access_policies:
        log.info("Deleting access policy for %s: %s", collection_name, policy["name"])
        aoss_client.conn.delete_access_policy(name=policy["name"], type="data")

    for policy_type in ["encryption", "network"]:
        policies = aoss_client.conn.list_security_policies(
            type=policy_type, resource=[f"collection/{collection_name}"]
        )["securityPolicySummaries"]
        if not policies:
            raise Exception("No security policies found?")
        log.info("Found %s security policies for %s: %s", policy_type, collection_name, policies)

        for policy in policies:
            log.info("Deleting %s security policy for %s: %s", policy_type, collection_name, policy["name"])
            aoss_client.conn.delete_security_policy(name=policy["name"], type=policy_type)


with DAG(
    dag_id=DAG_ID,
    schedule="@once",
    start_date=datetime(2021, 1, 1),
    tags=["example"],
    catchup=False,
) as dag:
    test_context = sys_test_context_task()
    env_id = test_context["ENV_ID"]

    aoss_client = OpenSearchServerlessHook()
    bedrock_agent_client = BedrockAgentHook()

    region_name = boto3.session.Session().region_name

    naming_prefix = "bedrock-kb-"
    bucket_name = f"{naming_prefix}{env_id}"
    index_name = f"{naming_prefix}index-{env_id}"
    knowledge_base_name = f"{naming_prefix}{env_id}"
    vector_store_name = f"{naming_prefix}{env_id}"
    data_source_name = f"{naming_prefix}ds-{env_id}"

    create_bucket = S3CreateBucketOperator(task_id="create_bucket", bucket_name=bucket_name)

    opensearch_policies = create_opensearch_policies(
        bedrock_role_arn=test_context[ROLE_ARN_KEY],
        collection_name=vector_store_name,
        policy_name_suffix=env_id,
    )

    collection = create_collection(collection_name=vector_store_name)

    # [START howto_sensor_opensearch_collection_active]
    await_collection = OpenSearchServerlessCollectionActiveSensor(
        task_id="await_collection",
        collection_name=vector_store_name,
    )
    # [END howto_sensor_opensearch_collection_active]

    PROMPT = "What color is an orange?"
    # [START howto_operator_invoke_claude_model]
    invoke_claude_completions = BedrockInvokeModelOperator(
        task_id="invoke_claude_completions",
        model_id=CLAUDE_MODEL_ID,
        input_data={"max_tokens_to_sample": 4000, "prompt": f"\n\nHuman: {PROMPT}\n\nAssistant:"},
    )
    # [END howto_operator_invoke_claude_model]

    # [START howto_operator_bedrock_create_knowledge_base]
    create_knowledge_base = BedrockCreateKnowledgeBaseOperator(
        task_id="create_knowledge_base",
        name=knowledge_base_name,
        embedding_model_arn=f"arn:aws:bedrock:{region_name}::foundation-model/{TITAN_MODEL_ID}",
        role_arn=test_context[ROLE_ARN_KEY],
        storage_config={
            "type": "OPENSEARCH_SERVERLESS",
            "opensearchServerlessConfiguration": {
                "collectionArn": get_collection_arn(collection),
                "vectorIndexName": index_name,
                "fieldMapping": {
                    "vectorField": "vector",
                    "textField": "text",
                    "metadataField": "text-metadata",
                },
            },
        },
    )
    # [END howto_operator_bedrock_create_knowledge_base]
    create_knowledge_base.wait_for_completion = False

    # [START howto_sensor_bedrock_knowledge_base_active]
    await_knowledge_base = BedrockKnowledgeBaseActiveSensor(
        task_id="await_knowledge_base", knowledge_base_id=create_knowledge_base.output
    )
    # [END howto_sensor_bedrock_knowledge_base_active]

    # [START howto_operator_bedrock_create_data_source]
    create_data_source = BedrockCreateDataSourceOperator(
        task_id="create_data_source",
        knowledge_base_id=create_knowledge_base.output,
        name=data_source_name,
        bucket_name=bucket_name,
    )
    # [END howto_operator_bedrock_create_data_source]

    # In this demo, delete_data_source and delete_bucket are both trying to delete
    # the data from the S3 bucket and occasionally hitting a conflict.  This ensures that
    # delete_data_source doesn't attempt to delete the files, leaving that duty to delete_bucket.
    create_data_source.create_data_source_kwargs["dataDeletionPolicy"] = "RETAIN"

    # [START howto_operator_bedrock_ingest_data]
    ingest_data = BedrockIngestDataOperator(
        task_id="ingest_data",
        knowledge_base_id=create_knowledge_base.output,
        data_source_id=create_data_source.output,
    )
    # [END howto_operator_bedrock_ingest_data]
    ingest_data.wait_for_completion = False

    # [START howto_sensor_bedrock_ingest_data]
    await_ingest = BedrockIngestionJobSensor(
        task_id="await_ingest",
        knowledge_base_id=create_knowledge_base.output,
        data_source_id=create_data_source.output,
        ingestion_job_id=ingest_data.output,
    )
    # [END howto_sensor_bedrock_ingest_data]

    # [START howto_operator_bedrock_knowledge_base_rag]
    knowledge_base_rag = BedrockRaGOperator(
        task_id="knowledge_base_rag",
        input="Who was the CEO of Amazon in 2022?",
        source_type="KNOWLEDGE_BASE",
        model_arn=f"arn:aws:bedrock:{region_name}::foundation-model/{CLAUDE_MODEL_ID}",
        knowledge_base_id=create_knowledge_base.output,
    )
    # [END howto_operator_bedrock_knowledge_base_rag]

    # [START howto_operator_bedrock_retrieve]
    retrieve = BedrockRetrieveOperator(
        task_id="retrieve",
        knowledge_base_id=create_knowledge_base.output,
        retrieval_query="Who was the CEO of Amazon in 1997?",
    )
    # [END howto_operator_bedrock_retrieve]

    delete_bucket = S3DeleteBucketOperator(
        task_id="delete_bucket",
        trigger_rule=TriggerRule.ALL_DONE,
        bucket_name=bucket_name,
        force_delete=True,
    )

    chain(
        # TEST SETUP
        test_context,
        create_bucket,
        opensearch_policies,
        collection,
        await_collection,
        create_vector_index(index_name=index_name, collection_id=collection, region=region_name),
        copy_data_to_s3(bucket=bucket_name),
        # TEST BODY
        invoke_claude_completions,
        create_knowledge_base,
        await_knowledge_base,
        create_data_source,
        ingest_data,
        await_ingest,
        knowledge_base_rag,
        external_sources_rag_group(),
        retrieve,
        delete_data_source(
            knowledge_base_id=create_knowledge_base.output,
            data_source_id=create_data_source.output,
        ),
        delete_knowledge_base(knowledge_base_id=create_knowledge_base.output),
        # TEST TEARDOWN
        delete_vector_index(index_name=index_name, collection_id=collection),
        delete_opensearch_policies(collection_name=vector_store_name),
        delete_collection(collection_id=collection),
        delete_bucket,
    )

    from tests_common.test_utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()

from tests_common.test_utils.system_tests import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
