# Source code for the camel.toolkits.dalle_toolkit module.

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import base64
import os
import uuid
from io import BytesIO
from typing import List, Optional

from openai import OpenAI
from PIL import Image

from camel.logger import get_logger
from camel.toolkits import FunctionTool
from camel.toolkits.base import BaseToolkit
from camel.utils import MCPServer

# Module-level logger; the toolkit methods below report their failures
# through it (via logger.error) before returning an error value.
logger = get_logger(__name__)


@MCPServer()
class DalleToolkit(BaseToolkit):
    r"""A class representing a toolkit for image generation using OpenAI's
    DALL-E model.
    """

    def __init__(
        self,
        timeout: Optional[float] = None,
    ):
        r"""Initializes a new instance of the DalleToolkit class.

        Args:
            timeout (Optional[float]): The timeout value for API requests
                in seconds. If None, no timeout is applied.
                (default: :obj:`None`)
        """
        super().__init__(timeout=timeout)

    def base64_to_image(self, base64_string: str) -> Optional[Image.Image]:
        r"""Converts a base64 encoded string into a PIL Image object.

        The pixel data is fully decoded before returning, so corrupt or
        truncated payloads are reported here (as ``None``) instead of
        failing later when the image is saved or otherwise used.

        Args:
            base64_string (str): The base64 encoded string of the image.

        Returns:
            Optional[Image.Image]: The PIL Image object or None if
                conversion fails.
        """
        try:
            # Decode the base64 string to get the raw image bytes.
            image_data = base64.b64decode(base64_string)
            # Image.open() is lazy and only parses the header; load()
            # forces the actual decode so bad data raises inside this try
            # block rather than at some later call site.
            image = Image.open(BytesIO(image_data))
            image.load()
            return image
        except Exception as e:
            error_msg = (
                f"An error occurred while converting base64 to image: {e}"
            )
            logger.error(error_msg)
            return None

    def image_path_to_base64(self, image_path: str) -> str:
        r"""Converts the file path of an image to a Base64 encoded string.

        Args:
            image_path (str): The path to the image file.

        Returns:
            str: A Base64 encoded string representing the content of the
                image file, or an error message if the file could not be
                read.
        """
        try:
            with open(image_path, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode('utf-8')
        except Exception as e:
            error_msg = (
                f"An error occurred while converting image path to base64: {e}"
            )
            logger.error(error_msg)
            return error_msg

    def image_to_base64(self, image: Image.Image) -> str:
        r"""Converts an image into a base64-encoded string.

        The image is serialized in PNG format and the resulting bytes are
        base64-encoded. If encoding fails, the error is logged and the
        error message is returned instead of the encoded data.

        Args:
            image (Image.Image): The image object to be encoded; any image
                that PIL can save in PNG format is supported.

        Returns:
            str: A base64-encoded string of the image, or an error message
                if the encoding failed.
        """
        try:
            with BytesIO() as buffered_image:
                image.save(buffered_image, format="PNG")
                # getvalue() returns the whole buffer regardless of the
                # current position, so no seek(0) is needed.
                image_bytes = buffered_image.getvalue()
            return base64.b64encode(image_bytes).decode('utf-8')
        except Exception as e:
            error_msg = f"An error occurred: {e}"
            logger.error(error_msg)
            return error_msg

    def get_dalle_img(self, prompt: str, image_dir: str = "img") -> str:
        r"""Generate an image using OpenAI's DALL-E model.
            The generated image is saved to the specified directory.

        Args:
            prompt (str): The text prompt based on which the image is
                generated.
            image_dir (str): The directory to save the generated image.
                Defaults to 'img'.

        Returns:
            str: The path to the saved image, or an error message if no
                image could be generated, decoded, or written to disk.
        """
        dalle_client = OpenAI()
        response = dalle_client.images.generate(
            model="dall-e-3",
            prompt=prompt,
            size="1024x1792",
            quality="standard",
            n=1,  # NOTE: now dall-e-3 only supports n=1
            response_format="b64_json",
        )
        if not response.data:
            error_msg = "No image data returned from DALL-E API."
            logger.error(error_msg)
            return error_msg

        # b64_json is optional in the response model; check it explicitly
        # instead of passing None into the decoder.
        image_b64 = response.data[0].b64_json
        if image_b64 is None:
            error_msg = "DALL-E API response did not contain base64 data."
            logger.error(error_msg)
            return error_msg

        image = self.base64_to_image(image_b64)
        if image is None:
            error_msg = "Failed to convert base64 string to image."
            logger.error(error_msg)
            return error_msg

        # Filesystem failures are reported the same way as the other
        # failure paths above, instead of escaping as exceptions.
        try:
            os.makedirs(image_dir, exist_ok=True)
            image_path = os.path.join(image_dir, f"{uuid.uuid4()}.png")
            image.save(image_path)
        except OSError as e:
            error_msg = f"Failed to save generated image: {e}"
            logger.error(error_msg)
            return error_msg
        return image_path

    def get_tools(self) -> List[FunctionTool]:
        r"""Returns a list of FunctionTool objects representing the
        functions in the toolkit.

        Returns:
            List[FunctionTool]: A list of FunctionTool objects
                representing the functions in the toolkit.
        """
        return [FunctionTool(self.get_dalle_img)]