Source code for camel.models.sglang_model

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
import subprocess
import threading
import time
from typing import Any, Dict, List, Optional, Type, Union

from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
from pydantic import BaseModel

from camel.configs import SGLANG_API_PARAMS, SGLangConfig
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import (
    ChatCompletion,
    ChatCompletionChunk,
    ModelType,
)
from camel.utils import BaseTokenCounter, OpenAITokenCounter


class SGLangModel(BaseModelBackend):
    r"""SGLang service interface.

    The backend talks to an SGLang server through its OpenAI-compatible API.
    If no ``url`` is given, a local server is launched lazily on the first
    request and shut down again after 10 minutes without requests.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created.
        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
            that will be fed into :obj:`openai.ChatCompletion.create()`. If
            :obj:`None`, :obj:`SGLangConfig().as_dict()` will be used.
            (default: :obj:`None`)
        api_key (Optional[str], optional): The API key for authenticating
            with the model service. SGLang doesn't need an API key, so a
            placeholder key is always sent to the server. Note that when
            ``tools`` is present in the config and the server is launched
            locally, the value held in ``self._api_key`` (expected to be
            this argument as stored by the base class) is passed to the
            server as ``--tool-call-parser``. (default: :obj:`None`)
        url (Optional[str], optional): The url to the model service. If not
            provided, a local server is started and
            :obj:`"http://127.0.0.1:30000/v1"` will be used.
            (default: :obj:`None`)
        token_counter (Optional[BaseTokenCounter], optional): Token counter
            to use for the model. If not provided, :obj:`OpenAITokenCounter(
            ModelType.GPT_4O_MINI)` will be used.
            (default: :obj:`None`)

    Reference: https://sgl-project.github.io/backend/openai_api_completions.html
    """

    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
    ) -> None:
        # Create the server bookkeeping first so that `cleanup()` (called
        # from `__del__`) finds these attributes even if a later step of
        # the constructor raises.
        self.server_process: Optional[subprocess.Popen] = None
        self.last_run_time: Optional[float] = (
            None  # Will be set when the server starts
        )
        self._lock = threading.Lock()
        self._inactivity_thread: Optional[threading.Thread] = None

        if model_config_dict is None:
            model_config_dict = SGLangConfig().as_dict()

        super().__init__(
            model_type, model_config_dict, api_key, url, token_counter
        )

        self._client: Optional[OpenAI] = None
        self._async_client: Optional[AsyncOpenAI] = None

        if self._url:
            # Initialize the clients if an existing URL is provided
            self._init_clients()

    def _init_clients(self) -> None:
        r"""Create the sync and async OpenAI clients for ``self._url``."""
        client_kwargs: Dict[str, Any] = dict(
            timeout=180,
            max_retries=3,
            api_key="Set-but-ignored",  # required but ignored
            base_url=self._url,
        )
        self._client = OpenAI(**client_kwargs)
        self._async_client = AsyncOpenAI(**client_kwargs)

    def _stop_server(self) -> None:
        r"""Terminate the locally launched server and reset related state.

        The caller must hold ``self._lock`` and must have checked that
        ``self.server_process`` is set.
        """
        _terminate_process(self.server_process)
        self.server_process = None
        # Invalidate both clients; they point at the server just killed.
        self._client = None
        self._async_client = None
        # `server_process` is only ever set for a server launched by
        # `_start_server`, which also filled in `_url`. Clear it so the next
        # request launches a fresh server instead of building clients for a
        # dead endpoint.
        self._url = None

    def _start_server(self) -> None:
        r"""Launch a local SGLang server if needed and create the clients.

        When no URL is configured, ``python -m sglang.launch_server`` is
        started on port 30000, the method waits until the server answers,
        and a background thread is started to stop it after inactivity.
        In every case the OpenAI clients are (re)created afterwards.

        Raises:
            RuntimeError: If launching the server or creating the clients
                fails for any reason.
        """
        try:
            if not self._url:
                tool_call_flag = self.model_config_dict.get("tools")
                tool_call_arg = (
                    f"--tool-call-parser {self._api_key} "
                    if tool_call_flag
                    else ""
                )
                cmd = (
                    f"python -m sglang.launch_server "
                    f"--model-path {self.model_type} "
                    f"{tool_call_arg}"
                    f"--port 30000 "
                    f"--host 0.0.0.0"
                )

                server_process = _execute_shell_command(cmd)
                try:
                    _wait_for_server("http://localhost:30000")
                except Exception:
                    # Do not leave an orphaned server behind when it never
                    # became ready.
                    _terminate_process(server_process)
                    raise
                self._url = "http://127.0.0.1:30000/v1"
                self.server_process = server_process
                # Start the inactivity monitor in a background thread
                self._inactivity_thread = threading.Thread(
                    target=self._monitor_inactivity, daemon=True
                )
                self._inactivity_thread.start()
            self.last_run_time = time.time()
            # Initialize the clients after the server starts
            self._init_clients()
        except Exception as e:
            raise RuntimeError(f"Failed to start SGLang server: {e}") from e

    def _ensure_server_running(self) -> None:
        r"""Ensures that the server is running. If not, starts the server."""
        with self._lock:
            # Clients are missing before the first launch and after the
            # server has been shut down. With a user-supplied URL they are
            # created in `__init__`, so nothing has to be rebuilt here.
            if self._client is None or self._async_client is None:
                self._start_server()

    def _monitor_inactivity(self) -> None:
        r"""Monitor whether the server process has been inactive for over 10
        minutes and terminate it if so. Runs in a background thread.
        """
        while True:
            # Check every 10 seconds
            time.sleep(10)
            with self._lock:
                # Over 10 minutes since the last request
                if self.last_run_time and (
                    time.time() - self.last_run_time > 600
                ):
                    if self.server_process:
                        self._stop_server()
                        logging.info(
                            "Server process terminated due to inactivity."
                        )
                    break

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        if not self._token_counter:
            self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
        return self._token_counter

    def check_model_config(self):
        r"""Check whether the model configuration contains any
        unexpected arguments to SGLang API.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to SGLang API.
        """
        for param in self.model_config_dict:
            if param not in SGLANG_API_PARAMS:
                raise ValueError(
                    f"Unexpected argument `{param}` is "
                    "input into SGLang model backend."
                )

    async def _arun(
        self,
        messages: List[OpenAIMessage],
        response_format: Optional[Type[BaseModel]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
        r"""Runs inference of OpenAI chat completion asynchronously.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.
            response_format (Optional[Type[BaseModel]], optional): Accepted
                for interface compatibility; not used by this method.
            tools (Optional[List[Dict[str, Any]]], optional): Accepted for
                interface compatibility; not used by this method.

        Returns:
            Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `AsyncStream[ChatCompletionChunk]` in the stream mode.

        Raises:
            RuntimeError: If the server could not be started or the async
                client is not initialized.
        """
        # Ensure server is running
        self._ensure_server_running()

        with self._lock:
            # Update last run time
            self.last_run_time = time.time()
            # Take the reference under the lock so a concurrent shutdown
            # cannot swap it out between the check and the call below.
            async_client = self._async_client

        if async_client is None:
            raise RuntimeError(
                "Client is not initialized. Ensure the server is running."
            )

        response = await async_client.chat.completions.create(
            messages=messages,
            model=self.model_type,
            **self.model_config_dict,
        )

        return response

    def _run(
        self,
        messages: List[OpenAIMessage],
        response_format: Optional[Type[BaseModel]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        r"""Runs inference of OpenAI chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.
            response_format (Optional[Type[BaseModel]], optional): Accepted
                for interface compatibility; not used by this method.
            tools (Optional[List[Dict[str, Any]]], optional): Accepted for
                interface compatibility; not used by this method.

        Returns:
            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `Stream[ChatCompletionChunk]` in the stream mode.

        Raises:
            RuntimeError: If the server could not be started or the client
                is not initialized.
        """
        # Ensure server is running
        self._ensure_server_running()

        with self._lock:
            # Update last run time
            self.last_run_time = time.time()
            # Take the reference under the lock so a concurrent shutdown
            # cannot swap it out between the check and the call below.
            client = self._client

        if client is None:
            raise RuntimeError(
                "Client is not initialized. Ensure the server is running."
            )

        response = client.chat.completions.create(
            messages=messages,
            model=self.model_type,
            **self.model_config_dict,
        )

        return response

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends partial
        results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return self.model_config_dict.get('stream', False)

    def __del__(self):
        r"""Properly clean up resources when the model is destroyed."""
        self.cleanup()

    def cleanup(self):
        r"""Terminate the server process and clean up resources.

        Does nothing when no locally launched server is running.
        """
        with self._lock:
            if self.server_process:
                self._stop_server()
                logging.info("Server process terminated during cleanup.")
# Below are helper functions from sglang.utils
def _terminate_process(process):
    r"""Kill a process together with all of its child processes.

    Args:
        process: Process handle exposing a ``pid`` attribute.
    """
    _kill_process_tree(process.pid)


def _kill_process_tree(
    parent_pid, include_parent: bool = True, skip_pid: Optional[int] = None
):
    r"""Kill the process and all its child processes.

    Args:
        parent_pid: PID of the root process. If :obj:`None`, the current
            process is used as the root and is itself never killed.
        include_parent (bool): Whether to kill the root process as well.
            (default: :obj:`True`)
        skip_pid (Optional[int]): PID of a child process to leave alive.
            (default: :obj:`None`)
    """
    import os
    import signal

    import psutil

    if parent_pid is None:
        parent_pid = os.getpid()
        include_parent = False

    try:
        itself = psutil.Process(parent_pid)
    except psutil.NoSuchProcess:
        # Nothing to do: the process is already gone.
        return

    children = itself.children(recursive=True)
    for child in children:
        if child.pid == skip_pid:
            continue
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass

    if include_parent:
        try:
            itself.kill()

            # Sometime processes cannot be killed with SIGKILL
            # so we send an additional signal to kill them.
            if hasattr(signal, "SIGQUIT"):
                itself.send_signal(signal.SIGQUIT)
            else:
                # SIGQUIT is not available on every platform.
                itself.send_signal(signal.SIGTERM)
        except psutil.NoSuchProcess:
            pass


def _execute_shell_command(command: str) -> subprocess.Popen:
    r"""Execute a shell command and return the process handle

    The command is split on whitespace and executed directly (no shell is
    involved), so quoting is not interpreted and arguments must not contain
    spaces.

    Args:
        command: Shell command as a string (can include \\ line continuations)

    Returns:
        subprocess.Popen: Process handle
    """
    # Replace "\<newline>" continuations with a space; any backslash left
    # over is turned into a space as well before splitting on whitespace.
    command = command.replace("\\\n", " ").replace("\\", " ")
    parts = command.split()

    return subprocess.Popen(parts, text=True, stderr=subprocess.STDOUT)


def _wait_for_server(base_url: str, timeout: Optional[int] = 30) -> None:
    r"""Wait for the server to be ready by polling the /v1/models endpoint.

    The endpoint is polled about once per second until it answers with
    HTTP 200 or the timeout elapses.

    Args:
        base_url (str): The base URL of the server
        timeout (Optional[int]): Maximum time to wait in seconds.
            (default: :obj:`30`)

    Raises:
        TimeoutError: If the server does not answer with HTTP 200 within
            the timeout.
    """
    import requests

    # Set a default value if timeout is None
    actual_timeout = 30 if timeout is None else timeout

    start_time = time.time()
    while True:
        failure_cause: Optional[BaseException] = None
        try:
            response = requests.get(
                f"{base_url}/v1/models",
                headers={"Authorization": "Bearer None"},
                timeout=5,  # Add a timeout for the request itself
            )
        except requests.exceptions.RequestException as e:
            # Typically a refused connection while the server is booting.
            failure_cause = e
            failure_detail = str(e)
        else:
            if response.status_code == 200:
                time.sleep(5)
                print(
                    "\n NOTE: Typically, the server runs in a separate "
                    "terminal. In this notebook, we run the server and "
                    "notebook code together, so their outputs are combined. "
                    "To improve clarity, the server logs are displayed in "
                    "the original black color, while the notebook outputs "
                    "are highlighted in blue. "
                )
                return
            failure_detail = f"HTTP status {response.status_code}"

        if time.time() - start_time > actual_timeout:
            raise TimeoutError(
                f"Server did not become ready within "
                f"{actual_timeout} seconds: {failure_detail}"
            ) from failure_cause

        # Pause between polls for every kind of failure, including non-200
        # responses, so a server that is still loading is not hammered in a
        # tight loop.
        time.sleep(1)