Source code for camel.models.gemini_model

# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from camel.configs import Gemini_API_PARAMS
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import (
    ChatCompletion,
    ChatCompletionMessage,
    Choice,
    ModelType,
)
from camel.utils import (
    BaseTokenCounter,
    GeminiTokenCounter,
    api_keys_required,
)

if TYPE_CHECKING:
    from google.generativeai.types import ContentsType, GenerateContentResponse


class GeminiModel(BaseModelBackend):
    r"""Gemini API in a unified BaseModelBackend interface."""

    # NOTE: Currently "stream": True is not supported with Gemini due to
    # the limitation of the current camel design.

    def __init__(
        self,
        model_type: ModelType,
        model_config_dict: Dict[str, Any],
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
    ) -> None:
        r"""Constructor for Gemini backend.

        Args:
            model_type (ModelType): Model for which a backend is created.
            model_config_dict (Dict[str, Any]): A dictionary that will
                be fed into generate_content().
            api_key (Optional[str]): The API key for authenticating with
                the Gemini service. (default: :obj:`None`)
            url (Optional[str]): The URL to the Gemini service.
                (default: :obj:`None`)
            token_counter (Optional[BaseTokenCounter]): Token counter to
                use for the model. If not provided, `GeminiTokenCounter`
                will be used. (default: :obj:`None`)
        """
        import os

        import google.generativeai as genai
        from google.generativeai.types.generation_types import (
            GenerationConfig,
        )

        super().__init__(
            model_type, model_config_dict, api_key, url, token_counter
        )
        self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
        genai.configure(api_key=self._api_key)
        self._client = genai.GenerativeModel(self.model_type.value)
        # Split the config: keys that are `GenerationConfig` attributes
        # are popped out and wrapped into a `GenerationConfig` object;
        # everything else stays in `model_config_dict` and is passed
        # through to `generate_content()` unchanged.
        keys = list(self.model_config_dict.keys())
        generation_config_dict = {
            k: self.model_config_dict.pop(k)
            for k in keys
            if hasattr(GenerationConfig, k)
        }
        generation_config = genai.types.GenerationConfig(
            **generation_config_dict
        )
        self.model_config_dict["generation_config"] = generation_config

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        if not self._token_counter:
            self._token_counter = GeminiTokenCounter(self.model_type)
        return self._token_counter
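
    # A minimal sketch of the config split in `__init__` (hedged; the
    # key names are standard Gemini GenerationConfig fields and the
    # model type is an assumption, not taken from this file):
    #
    #   >>> m = GeminiModel(
    #   ...     ModelType.GEMINI_1_5_FLASH,
    #   ...     {"temperature": 0.7, "candidate_count": 1, "stream": False},
    #   ... )
    #   >>> sorted(m.model_config_dict)
    #   ['generation_config', 'stream']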

    @api_keys_required("GOOGLE_API_KEY")
    def run(
        self,
        messages: List[OpenAIMessage],
    ) -> ChatCompletion:
        r"""Runs inference of Gemini model.

        This method can handle multimodal input.

        Args:
            messages: Message list with the chat history in OpenAI
                format.

        Returns:
            response: A ChatCompletion object formatted for the
                OpenAI API.
        """
        response = self._client.generate_content(
            contents=self.to_gemini_req(messages),
            **self.model_config_dict,
        )
        response.resolve()
        return self.to_openai_response(response)
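
    # Example call shape (a hedged sketch; `model` stands for an
    # instance of this class and the reply text is invented):
    #
    #   >>> completion = model.run([{"role": "user", "content": "Hello"}])
    #   >>> completion.choices[0].message.content
    #   'Hello! How can I help you today?'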

    def check_model_config(self):
        r"""Check whether the model configuration contains any
        unexpected arguments to Gemini API.

        Raises:
            ValueError: If the model configuration dictionary contains
                any unexpected arguments to the Gemini API.
        """
        if self.model_config_dict is not None:
            for param in self.model_config_dict:
                if param not in Gemini_API_PARAMS:
                    raise ValueError(
                        f"Unexpected argument `{param}` is "
                        "input into Gemini model backend."
                    )
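
    # For example (a hedged sketch; it is assumed that
    # `frequency_penalty` is not among Gemini_API_PARAMS):
    #
    #   >>> model.model_config_dict["frequency_penalty"] = 0.5
    #   >>> model.check_model_config()
    #   Traceback (most recent call last):
    #       ...
    #   ValueError: Unexpected argument `frequency_penalty` is input into Gemini model backend.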

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends
        partial results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return self.model_config_dict.get('stream', False)

    def to_gemini_req(self, messages: List[OpenAIMessage]) -> 'ContentsType':
        r"""Converts the request from the OpenAI API format to the
        Gemini API request format.

        Args:
            messages: The request object from the OpenAI API.

        Returns:
            converted_messages: A list of messages formatted for
                Gemini API.
        """
        # Role reference:
        # https://ai.google.dev/api/python/google/generativeai/protos/Content
        converted_messages = []
        for message in messages:
            role = message.get('role')
            if role == 'assistant':
                role_to_gemini = 'model'
            else:
                role_to_gemini = 'user'
            converted_message = {
                "role": role_to_gemini,
                "parts": message.get("content"),
            }
            converted_messages.append(converted_message)
        return converted_messages
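
    # A minimal sketch of the conversion (message contents are
    # illustrative): "assistant" becomes "model"; every other role,
    # including "system", becomes "user".
    #
    #   >>> model.to_gemini_req([
    #   ...     {"role": "system", "content": "Be brief."},
    #   ...     {"role": "user", "content": "Hi"},
    #   ...     {"role": "assistant", "content": "Hello!"},
    #   ... ])
    #   [{'role': 'user', 'parts': 'Be brief.'},
    #    {'role': 'user', 'parts': 'Hi'},
    #    {'role': 'model', 'parts': 'Hello!'}]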

    def to_openai_response(
        self,
        response: 'GenerateContentResponse',
    ) -> ChatCompletion:
        r"""Converts the response from the Gemini API to the OpenAI API
        response format.

        Args:
            response: The response object returned by the Gemini API.

        Returns:
            openai_response: A ChatCompletion object formatted for the
                OpenAI API.
        """
        import time
        import uuid

        openai_response = ChatCompletion(
            id=f"chatcmpl-{uuid.uuid4().hex!s}",
            object="chat.completion",
            created=int(time.time()),
            model=self.model_type.value,
            choices=[],
        )
        for i, candidate in enumerate(response.candidates):
            content = ""
            if candidate.content and len(candidate.content.parts) > 0:
                content = candidate.content.parts[0].text
            # The mapping keys are the string form of the finish-reason
            # enum, so convert before the lookup; unknown reasons fall
            # back to "stop".
            finish_reason_mapping = {
                "FinishReason.STOP": "stop",
                "FinishReason.SAFETY": "content_filter",
                "FinishReason.RECITATION": "content_filter",
                "FinishReason.MAX_TOKENS": "length",
            }
            finish_reason = finish_reason_mapping.get(
                str(candidate.finish_reason), "stop"
            )
            choice = Choice(
                index=i,
                message=ChatCompletionMessage(
                    role="assistant", content=content
                ),
                finish_reason=finish_reason,
            )
            openai_response.choices.append(choice)
        return openai_response
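

# A minimal end-to-end usage sketch. It assumes a valid GOOGLE_API_KEY
# in the environment and that this camel version defines
# ModelType.GEMINI_1_5_FLASH; both are assumptions, not guarantees.
if __name__ == "__main__":
    model = GeminiModel(
        model_type=ModelType.GEMINI_1_5_FLASH,
        model_config_dict={"temperature": 0.2, "max_output_tokens": 256},
    )
    completion = model.run(
        [{"role": "user", "content": "Say hello in one short sentence."}]
    )
    print(completion.choices[0].message.content)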