# Source code for camel.toolkits.dalle_toolkit
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import base64
import os
import uuid
from io import BytesIO
from typing import List, Optional
from openai import OpenAI
from PIL import Image
from camel.logger import get_logger
from camel.toolkits import FunctionTool
from camel.toolkits.base import BaseToolkit
from camel.utils import MCPServer
# Module-level logger; the DalleToolkit methods report their failures via it.
logger = get_logger(__name__)
[docs]
@MCPServer()
class DalleToolkit(BaseToolkit):
    r"""A class representing a toolkit for image generation using OpenAI's
    DALL-E model.

    Besides the image generation tool itself, the class offers small helpers
    for converting between base64 strings, image files and PIL images.
    """

    def __init__(
        self,
        timeout: Optional[float] = None,
    ):
        r"""Initializes a new instance of the DalleToolkit class.

        Args:
            timeout (Optional[float]): The timeout value for API requests
                in seconds. If None, no timeout is applied.
                (default: :obj:`None`)
        """
        super().__init__(timeout=timeout)

    def base64_to_image(self, base64_string: str) -> Optional[Image.Image]:
        r"""Converts a base64 encoded string into a PIL Image object.

        Args:
            base64_string (str): The base64 encoded string of the image.

        Returns:
            Optional[Image.Image]: The PIL Image object or None if conversion
                fails. Failures are logged, never raised.
        """
        try:
            # Decode the base64 string to get the image data
            image_data = base64.b64decode(base64_string)
            # Create a memory buffer for the image data
            image_buffer = BytesIO(image_data)
            # Open the image using the PIL library
            image = Image.open(image_buffer)
            return image
        except Exception as e:
            error_msg = (
                f"An error occurred while converting base64 to image: {e}"
            )
            logger.error(error_msg)
            return None

    def image_path_to_base64(self, image_path: str) -> str:
        r"""Converts the file path of an image to a Base64 encoded string.

        Args:
            image_path (str): The path to the image file.

        Returns:
            str: A Base64 encoded string representing the content of the image
                file. If the file cannot be read, the error is logged and the
                error message is returned instead.
        """
        try:
            with open(image_path, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode('utf-8')
        except Exception as e:
            error_msg = (
                f"An error occurred while converting image path to base64: {e}"
            )
            logger.error(error_msg)
            return error_msg

    def image_to_base64(self, image: Image.Image) -> str:
        r"""Converts an image into a base64-encoded string.

        This function takes an image object as input, encodes the image into a
        PNG format base64 string, and returns it.
        If the encoding process encounters an error, it logs the error
        message and returns that message instead of the encoded image.

        Args:
            image: The image object to be encoded, supports any image format
                that can be saved in PNG format.

        Returns:
            str: A base64-encoded string of the image, or the error message
                if encoding fails.
        """
        try:
            with BytesIO() as buffered_image:
                image.save(buffered_image, format="PNG")
                buffered_image.seek(0)
                image_bytes = buffered_image.read()
                base64_str = base64.b64encode(image_bytes).decode('utf-8')
                return base64_str
        except Exception as e:
            error_msg = f"An error occurred: {e}"
            logger.error(error_msg)
            return error_msg

    def get_dalle_img(self, prompt: str, image_dir: str = "img") -> str:
        r"""Generate an image using OpenAI's DALL-E model.

        The generated image is saved to the specified directory under a
        random UUID file name with a ``.png`` extension.

        Args:
            prompt (str): The text prompt based on which the image is
                generated.
            image_dir (str): The directory to save the generated image.
                Defaults to 'img'.

        Returns:
            str: The path to the saved image. On any failure (client
                creation, API request, empty or undecodable response,
                saving the file) the error is logged and the error message
                is returned instead, so the caller always gets a string.
        """
        # Client construction and the request share one guard so that a
        # missing API key or an API error yields an error string, matching
        # the other failure paths of this method.
        try:
            dalle_client = OpenAI()
            response = dalle_client.images.generate(
                model="dall-e-3",
                prompt=prompt,
                size="1024x1792",
                quality="standard",
                n=1,  # NOTE: now dall-e-3 only supports n=1
                response_format="b64_json",
            )
        except Exception as e:
            error_msg = f"An error occurred while generating the image: {e}"
            logger.error(error_msg)
            return error_msg

        if not response.data:
            error_msg = "No image data returned from DALL-E API."
            logger.error(error_msg)
            return error_msg

        # ``b64_json`` is optional in the response model; treat a missing
        # payload the same way as an undecodable one.
        image_b64 = response.data[0].b64_json
        image = self.base64_to_image(image_b64) if image_b64 else None
        if image is None:
            error_msg = "Failed to convert base64 string to image."
            logger.error(error_msg)
            return error_msg

        # Directory creation and saving can both fail; ``save`` is also
        # where a corrupt payload may first surface, hence the broad catch.
        try:
            os.makedirs(image_dir, exist_ok=True)
            image_path = os.path.join(image_dir, f"{uuid.uuid4()}.png")
            image.save(image_path)
        except Exception as e:
            error_msg = f"An error occurred while saving the image: {e}"
            logger.error(error_msg)
            return error_msg

        return image_path

    def get_tools(self) -> List[FunctionTool]:
        r"""Returns a list of FunctionTool objects representing the
        functions in the toolkit.

        Returns:
            List[FunctionTool]: A list of FunctionTool objects
                representing the functions in the toolkit.
        """
        return [FunctionTool(self.get_dalle_img)]