이번 기술블로그의 주제는 AI&ML로써, AWS에서 제공하는 Gen AI 서비스인 AWS Bedrock 서비스를 이용하여 ChatBot 서비스를 테스트 해보겠습니다.
AWS Bedrock 서비스는 현재 서울리전에서는 지원하지 않는 관계로 모든 Hands-On은 버지니아 북부 리전에서 진행하도록 합니다.
AWS Bedrock 서비스는 최초 서비스 시작 시 다음과 같은 모델 엑세스를 요청하여야 하며, 본인의 용도에 맞는 모델을 선택합니다.
각 모델에 대한 사용비용은 다음과 같습니다.
https://aws.amazon.com/ko/bedrock/pricing/
이번 Hands-On 에서는 Claude 3 Haiku 모델을 사용할 예정이며, 다음과 같이 모델 액세스 요청을 제일 먼저 활성화 해야 합니다.
모델의 정상적인 동작 확인을 위해 플레이그라운드 메뉴로 이동합니다.
모델선택에서 Claude 3 Haiku를 선택합니다.
기본적인 질문에 대한 응답을 확인합니다. 현재까지 모델 언어에서 학습되어 있는 기본적인 답변을 확인할 수 있습니다.
이제 다음의 예시 System prompt 를 통해 Cluade가 답변할 형식을 지정할 수 있도록 설정 할 수 있습니다.
Prompt 형식 예시 |
입력한 형식의 기반하여 Claude가 답변을 도출한 내용
다음은 서울시 원문 정보 결재 내용 중 하나인 “[현장대응단] 2024년 7월 현장 소방공무원 직장훈련 사전(기본) 계획” 을 Prompt에 파일 형태로 사전 업로드 한 후 질의 응답한 내용입니다.
다음은 해당 파일을 업로드 한 후 Claude에서 답변한 내용입니다.
또한 Claude는 이미지 분석하여 답변을 할 수도 있습니다.
Bedrock 에 업로드 된 서울시 안심 의료 기관 운영 이미지
Bedrock에 사전 입력한 이미지를 토대로 질의 응답한 내용
위와 같이 이번 Hands-On 을 통해 AWS Bedrock 서비스의 기본적인 사용법과 Prompt 설정을 통한 대답 형식을 확인 할 수 있었습니다.
또한 아래의 예시 코드를 통해 간단한 호출 테스트를 수행할 수 있습니다.
package com.bedrock.bedrock.service;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import org.json.JSONObject;
import org.springframework.stereotype.Service;
import lombok.RequiredArgsConstructor;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrock.BedrockAsyncClient;
import software.amazon.awssdk.services.bedrock.BedrockClient;
import software.amazon.awssdk.services.bedrock.model.FoundationModelDetails;
import software.amazon.awssdk.services.bedrock.model.FoundationModelSummary;
import software.amazon.awssdk.services.bedrock.model.GetFoundationModelResponse;
import software.amazon.awssdk.services.bedrock.model.ListFoundationModelsResponse;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.bedrockruntime.model.ResponseStream;
@RequiredArgsConstructor
@Service
public class GenerateReplyServiceImpl implements GenerateReplyService{
@Override
public String generateReply(String message){
String accessKey = "";
String secretKey = "";
AwsBasicCredentials awsCredentials = AwsBasicCredentials.create(accessKey, secretKey);
BedrockClient bedrockClient = BedrockClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(awsCredentials))
.region(software.amazon.awssdk.regions.Region.US_EAST_1)
.build();
//List<FoundationModelSummary> a = listFoundationModels(bedrockClient);
/**
start
*/
String modelId = "anthropic.claude-3-sonnet-20240229-v1:0";
System.out.println("Initializing the Amazon Bedrock async client...");
// System.out.printf("Region: %s%n", region.toString());
BedrockAsyncClient client = BedrockAsyncClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(awsCredentials))
.region(software.amazon.awssdk.regions.Region.US_EAST_1)
.build();
getFoundationModel(client, modelId);
/*stop
*/
System.out.println("=".repeat(67));
System.out.println("Welcome to the Amazon Bedrock Runtime Demo with Anthropic Claude 3.");
System.out.println("=".repeat(67));
var prompt = "Hi, how are you?.";
System.out.println("Prompt: " + prompt);
System.out.println("-".repeat(67));
System.out.println("Streaming response:");
try {
JSONObject messagesApiResponse = invokeModelWithResponseStream(prompt);
System.out.println("\n" + "-".repeat(67));
System.out.println("Structured response:");
System.out.println(messagesApiResponse.toString(2));
} catch (Exception e) {
System.out.println("Couldn't invoke model using the Messages API, here's why: " + e.getMessage());
}
return "test";
}
public static JSONObject invokeModelWithResponseStream(String prompt) {
String accessKey = "";
String secretKey = "";
AwsBasicCredentials awsCredentials = AwsBasicCredentials.create(accessKey, secretKey);
BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder()
.region(software.amazon.awssdk.regions.Region.US_EAST_1)
.credentialsProvider(StaticCredentialsProvider.create(awsCredentials))
.build();
String modelId = "anthropic.claude-3-sonnet-20240229-v1:0k";
// Prepare the JSON payload for the Messages API request
var payload = new JSONObject()
.put("anthropic_version", "bedrock-2023-05-31")
.put("max_tokens", 1000)
.append("messages", new JSONObject()
.put("role", "user")
.append("content", new JSONObject()
.put("type", "text")
.put("text", prompt)
));
// Create the request object using the payload and the model ID
var request = InvokeModelWithResponseStreamRequest.builder()
.contentType("application/json")
.body(SdkBytes.fromUtf8String(payload.toString()))
.modelId(modelId)
.build();
// Create a handler to print the stream in real-time and add metadata to a response object
JSONObject structuredResponse = new JSONObject();
var handler = createMessagesApiResponseStreamHandler(structuredResponse);
// Invoke the model with the request payload and the response stream handler
client.invokeModelWithResponseStream(request, handler).join();
return structuredResponse;
}
private static InvokeModelWithResponseStreamResponseHandler createMessagesApiResponseStreamHandler(JSONObject structuredResponse) {
AtomicReference<String> completeMessage = new AtomicReference<>("");
Consumer<ResponseStream> responseStreamHandler = event -> event.accept(InvokeModelWithResponseStreamResponseHandler.Visitor.builder()
.onChunk(c -> {
// Decode the chunk
var chunk = new JSONObject(c.bytes().asUtf8String());
// The Messages API returns different types:
var chunkType = chunk.getString("type");
if ("message_start".equals(chunkType)) {
// The first chunk contains information about the message role
String role = chunk.optJSONObject("message").optString("role");
structuredResponse.put("role", role);
} else if ("content_block_delta".equals(chunkType)) {
// These chunks contain the text fragments
var text = chunk.optJSONObject("delta").optString("text");
// Print the text fragment to the console ...
System.out.print(text);
// ... and append it to the complete message
completeMessage.getAndUpdate(current -> current + text);
} else if ("message_delta".equals(chunkType)) {
// This chunk contains the stop reason
var stopReason = chunk.optJSONObject("delta").optString("stop_reason");
structuredResponse.put("stop_reason", stopReason);
} else if ("message_stop".equals(chunkType)) {
// The last chunk contains the metrics
JSONObject metrics = chunk.optJSONObject("amazon-bedrock-invocationMetrics");
structuredResponse.put("metrics", new JSONObject()
.put("inputTokenCount", metrics.optString("inputTokenCount"))
.put("outputTokenCount", metrics.optString("outputTokenCount"))
.put("firstByteLatency", metrics.optString("firstByteLatency"))
.put("invocationLatency", metrics.optString("invocationLatency")));
}
})
.build());
return InvokeModelWithResponseStreamResponseHandler.builder()
.onEventStream(stream -> stream.subscribe(responseStreamHandler))
.onComplete(() ->
// Add the complete message to the response object
structuredResponse.append("content", new JSONObject()
.put("type", "text")
.put("text", completeMessage.get())))
.build();
}
public static List<FoundationModelSummary> listFoundationModels(BedrockClient bedrockClient) {
try {
ListFoundationModelsResponse response = bedrockClient.listFoundationModels(r -> {});
List<FoundationModelSummary> models = response.modelSummaries();
if (models.isEmpty()) {
System.out.println("No available foundation models in ");
} else {
for (FoundationModelSummary model : models) {
System.out.println("Model ID: " + model.modelId());
System.out.println("Provider: " + model.providerName());
System.out.println("Name: " + model.modelName());
System.out.println();
}
}
return models;
} catch (SdkClientException e) {
System.err.println(e.getMessage());
throw new RuntimeException(e);
}
}
public static FoundationModelDetails getFoundationModel(BedrockAsyncClient bedrockClient, String modelIdentifier) {
try {
CompletableFuture<GetFoundationModelResponse> future = bedrockClient.getFoundationModel(
r -> r.modelIdentifier(modelIdentifier)
);
FoundationModelDetails model = future.get().modelDetails();
System.out.println(" Model ID: " + model.modelId());
System.out.println(" Model ARN: " + model.modelArn());
System.out.println(" Model Name: " + model.modelName());
System.out.println(" Provider Name: " + model.providerName());
// System.out.println(" Lifecycle status: " + model.modelLifecycle().statusAsString());
System.out.println(" Input modalities: " + model.inputModalities());
System.out.println(" Output modalities: " + model.outputModalities());
System.out.println(" Supported customizations: " + model.customizationsSupported());
System.out.println(" Supported inference types: " + model.inferenceTypesSupported());
System.out.println(" Response streaming supported: " + model.responseStreamingSupported());
return model;
} catch (ExecutionException e) {
if (e.getMessage().contains("ValidationException")) {
throw new IllegalArgumentException(e.getMessage());
} else {
System.err.println(e.getMessage());
throw new RuntimeException(e);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
System.err.println(e.getMessage());
throw new RuntimeException(e);
}
}
// snippet-end:[bedrock.java2.get_foundation_model_async.main]
}