[Hands-On] AWS Bed-rock을 이용한 생성 형 AI 기초 실습

파운데이션 모델(FM)을 사용하여 생성형 AI 애플리케이션을 구축 - Amazon Bedrock - AWS
Amazon Bedrock은 AI21 Labs, Anthropic, Cohere, Meta, Mistral AI, Stability AI 및 Amazon과 같은 선도적인 AI 회사의 다양한 고성능 파운데이션 모델(FM)을 단일 API를 통해 제공하는 완전 관리형 서비스입니다. 이 서비스를 사용하면 보안, 개인정보 보호 및 책임형 AI를 포함하여 생성형 AI 애플리케이션을 구축하는 데 필요한 광범위한 기능 세트를 활용합니다. Amazon Bedrock을 사용하면 사용 사례에 맞게 상위 FM을 쉽게 실험 및 평가하고, 미세 조정 및 검색 증강 생성(RAG)과 같은 기술을 사용하여 데이터로 비공개로 사용자 지정하고, 엔터프라이즈 시스템 및 데이터 소스를 사용하여 작업을 실행하는 에이전트를 구축할 수 있습니다. Amazon Bedrock은 서버리스 이므로 인프라를 관리할 필요가 없으며 이미 익숙한 AWS 서비스를 사용하여 생성형 AI 기능을 애플리케이션에 안전하게 통합하고 배포할 수 있습니다.

이번 기술블로그의 주제는 AI&ML로써, AWS에서 제공하는 Gen AI 서비스인 AWS Bedrock 서비스를 이용하여 ChatBot 서비스를 테스트 해보겠습니다.

AWS Bedrock 서비스는 현재 서울리전에서는 지원하지 않는 관계로 모든 Hands-On은 버지니아 북부 리전에서 진행하도록 합니다.

AWS Bedrock 서비스는 최초 서비스 시작 시 다음과 같은 모델 엑세스를 요청하여야 하며, 본인의 용도에 맞는 모델을 선택합니다.

각 모델에 대한 사용비용은 다음과 같습니다.
https://aws.amazon.com/ko/bedrock/pricing/

이번 Hands-On 에서는 Claude 3 Haiku 모델을 사용할 예정이며, 다음과 같이 모델 액세스 요청을 제일 먼저 활성화 해야 합니다.

모델의 정상적인 동작 확인을 위해 플레이그라운드 메뉴로 이동합니다.

모델선택에서 Claude 3 Haiku를 선택합니다.

기본적인 질문에 대한 응답을 확인합니다. 현재까지 모델 언어에서 학습되어 있는 기본적인 답변을 확인할 수 있습니다.

이제 다음의 예시 System prompt 를 통해 Cluade가 답변할 형식을 지정할 수 있도록 설정 할 수 있습니다.

Prompt 형식 예시

입력한 형식의 기반하여 Claude가 답변을 도출한 내용

다음은 서울시 원문 정보 결재 내용 중 하나인 “[현장대응단] 2024년 7월 현장 소방공무원 직장훈련 사전(기본) 계획” 을 Prompt에 파일 형태로 사전 업로드 한 후 질의 응답한 내용입니다.

다음은 해당 파일을 업로드 한 후 Claude에서 답변한 내용입니다.

또한 Claude는 이미지 분석하여 답변을 할 수도 있습니다.

Bedrock 에 업로드 된 서울시 안심 의료 기관 운영 이미지

Bedrock에 사전 입력한 이미지를 토대로 질의 응답한 내용

위와 같이 이번 Hands-On 을 통해 AWS Bedrock 서비스의 기본적인 사용법과 Prompt 설정을 통한 대답 형식을 확인 할 수 있었습니다.

또한 아래의 예시 코드를 통해 간단한 호출 테스트를 수행할 수 있습니다.

package com.bedrock.bedrock.service;
 
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
 
import org.json.JSONObject;
import org.springframework.stereotype.Service;
 
import lombok.RequiredArgsConstructor;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrock.BedrockAsyncClient;
import software.amazon.awssdk.services.bedrock.BedrockClient;
import software.amazon.awssdk.services.bedrock.model.FoundationModelDetails;
import software.amazon.awssdk.services.bedrock.model.FoundationModelSummary;
import software.amazon.awssdk.services.bedrock.model.GetFoundationModelResponse;
import software.amazon.awssdk.services.bedrock.model.ListFoundationModelsResponse;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.bedrockruntime.model.ResponseStream;
 
@RequiredArgsConstructor
@Service
public class GenerateReplyServiceImpl implements GenerateReplyService{
    @Override
    public String generateReply(String message){
        String accessKey = "";
        String secretKey = "";
        AwsBasicCredentials awsCredentials = AwsBasicCredentials.create(accessKey, secretKey);
        
        BedrockClient bedrockClient = BedrockClient.builder()
                .credentialsProvider(StaticCredentialsProvider.create(awsCredentials))
                .region(software.amazon.awssdk.regions.Region.US_EAST_1)
                .build();
        
        //List<FoundationModelSummary> a = listFoundationModels(bedrockClient);
        /**
        start
         */
        
 
        String modelId = "anthropic.claude-3-sonnet-20240229-v1:0";
 
        System.out.println("Initializing the Amazon Bedrock async client...");
        // System.out.printf("Region: %s%n", region.toString());
 
        BedrockAsyncClient client = BedrockAsyncClient.builder()
                .credentialsProvider(StaticCredentialsProvider.create(awsCredentials))
                .region(software.amazon.awssdk.regions.Region.US_EAST_1)
                .build();
 
        getFoundationModel(client, modelId);
 
        /*stop
         */
        System.out.println("=".repeat(67));
        System.out.println("Welcome to the Amazon Bedrock Runtime Demo with Anthropic Claude 3.");
        System.out.println("=".repeat(67));
 
        var prompt = "Hi, how are you?.";
        System.out.println("Prompt: " + prompt);
 
        System.out.println("-".repeat(67));
        System.out.println("Streaming response:");
        try {
            JSONObject messagesApiResponse = invokeModelWithResponseStream(prompt);
 
            System.out.println("\n" + "-".repeat(67));
            System.out.println("Structured response:");
            System.out.println(messagesApiResponse.toString(2));
 
        } catch (Exception e) {
            System.out.println("Couldn't invoke model using the Messages API, here's why: " + e.getMessage());
        }
 
 
        
        return "test";
    }
    public static JSONObject invokeModelWithResponseStream(String prompt) {
        String accessKey = "";
        String secretKey = "";
        AwsBasicCredentials awsCredentials = AwsBasicCredentials.create(accessKey, secretKey);
        
        BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder()
                                .region(software.amazon.awssdk.regions.Region.US_EAST_1)
                                .credentialsProvider(StaticCredentialsProvider.create(awsCredentials))
                                .build();
 
        String modelId = "anthropic.claude-3-sonnet-20240229-v1:0k";
 
        // Prepare the JSON payload for the Messages API request
        var payload = new JSONObject()
                .put("anthropic_version", "bedrock-2023-05-31")
                .put("max_tokens", 1000)
                .append("messages", new JSONObject()
                        .put("role", "user")
                        .append("content", new JSONObject()
                                .put("type", "text")
                                .put("text", prompt)
                        ));
 
        // Create the request object using the payload and the model ID
        var request = InvokeModelWithResponseStreamRequest.builder()
                .contentType("application/json")
                .body(SdkBytes.fromUtf8String(payload.toString()))
                .modelId(modelId)
                .build();
 
        // Create a handler to print the stream in real-time and add metadata to a response object
        JSONObject structuredResponse = new JSONObject();
        var handler = createMessagesApiResponseStreamHandler(structuredResponse);
 
        // Invoke the model with the request payload and the response stream handler
        client.invokeModelWithResponseStream(request, handler).join();
 
        return structuredResponse;
    }
 
    private static InvokeModelWithResponseStreamResponseHandler createMessagesApiResponseStreamHandler(JSONObject structuredResponse) {
        AtomicReference<String> completeMessage = new AtomicReference<>("");
 
        Consumer<ResponseStream> responseStreamHandler = event -> event.accept(InvokeModelWithResponseStreamResponseHandler.Visitor.builder()
                .onChunk(c -> {
                    // Decode the chunk
                    var chunk = new JSONObject(c.bytes().asUtf8String());
 
                    // The Messages API returns different types:
                    var chunkType = chunk.getString("type");
                    if ("message_start".equals(chunkType)) {
                        // The first chunk contains information about the message role
                        String role = chunk.optJSONObject("message").optString("role");
                        structuredResponse.put("role", role);
 
                    } else if ("content_block_delta".equals(chunkType)) {
                        // These chunks contain the text fragments
                        var text = chunk.optJSONObject("delta").optString("text");
                        // Print the text fragment to the console ...
                        System.out.print(text);
                        // ... and append it to the complete message
                        completeMessage.getAndUpdate(current -> current + text);
 
                    } else if ("message_delta".equals(chunkType)) {
                        // This chunk contains the stop reason
                        var stopReason = chunk.optJSONObject("delta").optString("stop_reason");
                        structuredResponse.put("stop_reason", stopReason);
 
                    } else if ("message_stop".equals(chunkType)) {
                        // The last chunk contains the metrics
                        JSONObject metrics = chunk.optJSONObject("amazon-bedrock-invocationMetrics");
                        structuredResponse.put("metrics", new JSONObject()
                                .put("inputTokenCount", metrics.optString("inputTokenCount"))
                                .put("outputTokenCount", metrics.optString("outputTokenCount"))
                                .put("firstByteLatency", metrics.optString("firstByteLatency"))
                                .put("invocationLatency", metrics.optString("invocationLatency")));
                    }
                })
                .build());
 
        return InvokeModelWithResponseStreamResponseHandler.builder()
                .onEventStream(stream -> stream.subscribe(responseStreamHandler))
                .onComplete(() ->
                        // Add the complete message to the response object
                        structuredResponse.append("content", new JSONObject()
                                .put("type", "text")
                                .put("text", completeMessage.get())))
                .build();
    }
 
    public static List<FoundationModelSummary> listFoundationModels(BedrockClient bedrockClient) {
 
        try {
            ListFoundationModelsResponse response = bedrockClient.listFoundationModels(r -> {});
 
            List<FoundationModelSummary> models = response.modelSummaries();
 
            if (models.isEmpty()) {
                System.out.println("No available foundation models in ");
            } else {
                for (FoundationModelSummary model : models) {
                    System.out.println("Model ID: " + model.modelId());
                    System.out.println("Provider: " + model.providerName());
                    System.out.println("Name:     " + model.modelName());
                    System.out.println();
                }
            }
 
            return models;
 
        } catch (SdkClientException e) {
            System.err.println(e.getMessage());
            throw new RuntimeException(e);
        }
    }
    public static FoundationModelDetails getFoundationModel(BedrockAsyncClient bedrockClient, String modelIdentifier) {
        try {
            CompletableFuture<GetFoundationModelResponse> future = bedrockClient.getFoundationModel(
                    r -> r.modelIdentifier(modelIdentifier)
            );
 
            FoundationModelDetails model = future.get().modelDetails();
 
            System.out.println(" Model ID:                     " + model.modelId());
            System.out.println(" Model ARN:                    " + model.modelArn());
            System.out.println(" Model Name:                   " + model.modelName());
            System.out.println(" Provider Name:                " + model.providerName());
            // System.out.println(" Lifecycle status:             " + model.modelLifecycle().statusAsString());
            System.out.println(" Input modalities:             " + model.inputModalities());
            System.out.println(" Output modalities:            " + model.outputModalities());
            System.out.println(" Supported customizations:     " + model.customizationsSupported());
            System.out.println(" Supported inference types:    " + model.inferenceTypesSupported());
            System.out.println(" Response streaming supported: " + model.responseStreamingSupported());
 
            return model;
 
        } catch (ExecutionException e) {
            if (e.getMessage().contains("ValidationException")) {
                throw new IllegalArgumentException(e.getMessage());
            } else {
                System.err.println(e.getMessage());
                throw new RuntimeException(e);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            System.err.println(e.getMessage());
            throw new RuntimeException(e);
        }
    }
    // snippet-end:[bedrock.java2.get_foundation_model_async.main]
}

댓글 달기

이메일 주소는 공개되지 않습니다. 필수 항목은 *(으)로 표시합니다