# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import os
import time
import uuid
from typing import Any, Dict, List, Optional, Union
import httpx
from openai import OpenAI, Stream
from camel.configs import (
    SAMBA_CLOUD_API_PARAMS,
    SAMBA_VERSE_API_PARAMS,
    SambaCloudAPIConfig,
)
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import (
    ChatCompletion,
    ChatCompletionChunk,
    CompletionUsage,
    ModelType,
)
from camel.utils import (
    BaseTokenCounter,
    OpenAITokenCounter,
    api_keys_required,
)
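
# Conditionally enable AgentOps LLM event tracking: the import is attempted
# only when an `AGENTOPS_API_KEY` is set, and `LLMEvent` falls back to `None`
# when AgentOps is unavailable.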
try:
    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import LLMEvent, record
    else:
        raise ImportError
except (ImportError, AttributeError):
    LLMEvent = None


class SambaModel(BaseModelBackend):
    r"""SambaNova service interface.

    Args:
        model_type (Union[ModelType, str]): Model for which a SambaNova
            backend is created. Supported models via SambaNova Cloud:
            `https://community.sambanova.ai/t/supported-models/193`.
            Supported models via the SambaVerse API are listed at
            `https://sambaverse.sambanova.ai/models`.
        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
            that will be fed into :obj:`openai.ChatCompletion.create()`. If
            :obj:`None`, :obj:`SambaCloudAPIConfig().as_dict()` will be used.
            (default: :obj:`None`)
        api_key (Optional[str], optional): The API key for authenticating
            with the SambaNova service. (default: :obj:`None`)
        url (Optional[str], optional): The URL to the SambaNova service.
            Currently supported are the SambaVerse API:
            :obj:`"https://sambaverse.sambanova.ai/api/predict"` and
            SambaNova Cloud:
            :obj:`"https://api.sambanova.ai/v1"`. (default:
            :obj:`https://api.sambanova.ai/v1`)
        token_counter (Optional[BaseTokenCounter], optional): Token counter to
            use for the model. If not provided, :obj:`OpenAITokenCounter(
            ModelType.GPT_4O_MINI)` will be used.
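
    Example:
        A minimal usage sketch (the model name below is illustrative; a
        valid ``SAMBA_API_KEY`` is assumed to be set in the environment):

        .. code-block:: python

            model = SambaModel(model_type="Meta-Llama-3.1-8B-Instruct")
            response = model.run(
                [{"role": "user", "content": "Hello!"}]
            )
            print(response.choices[0].message.content)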
"""

    @api_keys_required(
        [
            ("api_key", 'SAMBA_API_KEY'),
        ]
    )
    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
    ) -> None:
        if model_config_dict is None:
            model_config_dict = SambaCloudAPIConfig().as_dict()
        api_key = api_key or os.environ.get("SAMBA_API_KEY")
        url = url or os.environ.get(
            "SAMBA_API_BASE_URL",
            "https://api.sambanova.ai/v1",
        )
        super().__init__(
            model_type, model_config_dict, api_key, url, token_counter
        )
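        # An OpenAI-compatible client is only constructed for the SambaNova
        # Cloud endpoint; SambaVerse requests are sent directly with `httpx`
        # in `_run_non_streaming`.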
if self._url == "https://api.sambanova.ai/v1":
self._client = OpenAI(
timeout=180,
max_retries=3,
base_url=self._url,
api_key=self._api_key,
)

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        if not self._token_counter:
            self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
        return self._token_counter

    def check_model_config(self):
        r"""Check whether the model configuration contains any
        unexpected arguments to the SambaNova API.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to the SambaNova API.
        """
        if self._url == "https://sambaverse.sambanova.ai/api/predict":
            for param in self.model_config_dict:
                if param not in SAMBA_VERSE_API_PARAMS:
                    raise ValueError(
                        f"Unexpected argument `{param}` was "
                        "passed to the SambaVerse API."
                    )
        elif self._url == "https://api.sambanova.ai/v1":
            for param in self.model_config_dict:
                if param not in SAMBA_CLOUD_API_PARAMS:
                    raise ValueError(
                        f"Unexpected argument `{param}` was "
                        "passed to the SambaCloud API."
                    )
        else:
            raise ValueError(
                f"{self._url} is not supported, please check the URL to the"
                " SambaNova service."
            )

    def run(  # type: ignore[misc]
        self, messages: List[OpenAIMessage]
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        r"""Runs SambaNova's service.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `Stream[ChatCompletionChunk]` in the stream mode.
        """
if "tools" in self.model_config_dict:
del self.model_config_dict["tools"]
if self.model_config_dict.get("stream") is True:
return self._run_streaming(messages)
else:
return self._run_non_streaming(messages)

    def _run_streaming(
        self, messages: List[OpenAIMessage]
    ) -> Stream[ChatCompletionChunk]:
        r"""Handles streaming inference with SambaNova's API.

        Args:
            messages (List[OpenAIMessage]): A list of messages representing
                the chat history in OpenAI API format.

        Returns:
            Stream[ChatCompletionChunk]: A generator yielding
                `ChatCompletionChunk` objects as they are received from the
                API.

        Raises:
            RuntimeError: If the HTTP request fails.
            ValueError: If the API doesn't support stream mode.
        """
        # Handle SambaNova's Cloud API
        if self._url == "https://api.sambanova.ai/v1":
            response = self._client.chat.completions.create(
                messages=messages,
                model=self.model_type,
                **self.model_config_dict,
            )

            # Add AgentOps LLM Event tracking
            if LLMEvent:
                llm_event = LLMEvent(
                    thread_id=response.id,
                    prompt=" ".join(
                        [message.get("content") for message in messages]  # type: ignore[misc]
                    ),
                    prompt_tokens=response.usage.prompt_tokens,  # type: ignore[union-attr]
                    completion=response.choices[0].message.content,
                    completion_tokens=response.usage.completion_tokens,  # type: ignore[union-attr]
                    model=self.model_type,
                )
                record(llm_event)

            return response
        elif self._url == "https://sambaverse.sambanova.ai/api/predict":
            raise ValueError(
                "https://sambaverse.sambanova.ai/api/predict doesn't support"
                " stream mode"
            )
        raise RuntimeError(f"Unknown URL: {self._url}")

    def _run_non_streaming(
        self, messages: List[OpenAIMessage]
    ) -> ChatCompletion:
        r"""Handles non-streaming inference with SambaNova's API.

        Args:
            messages (List[OpenAIMessage]): A list of messages representing
                the chat history in OpenAI API format.

        Returns:
            ChatCompletion: A `ChatCompletion` object containing the complete
                response from the API.

        Raises:
            RuntimeError: If the HTTP request fails.
            ValueError: If the JSON response cannot be decoded or is missing
                expected data.
        """
        # Handle SambaNova's Cloud API
        if self._url == "https://api.sambanova.ai/v1":
            response = self._client.chat.completions.create(
                messages=messages,
                model=self.model_type,
                **self.model_config_dict,
            )

            # Add AgentOps LLM Event tracking
            if LLMEvent:
                llm_event = LLMEvent(
                    thread_id=response.id,
                    prompt=" ".join(
                        [message.get("content") for message in messages]  # type: ignore[misc]
                    ),
                    prompt_tokens=response.usage.prompt_tokens,  # type: ignore[union-attr]
                    completion=response.choices[0].message.content,
                    completion_tokens=response.usage.completion_tokens,  # type: ignore[union-attr]
                    model=self.model_type,
                )
                record(llm_event)

            return response
        # Handle SambaNova's SambaVerse API
        else:
            headers = {
                "Content-Type": "application/json",
                "key": str(self._api_key),
                "modelName": self.model_type,
            }
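            # SambaVerse expects each sampling parameter as a typed
            # `{"type": ..., "value": ...}` pair with string-encoded values.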
            data = {
                "instance": json.dumps(
                    {
                        "conversation_id": str(uuid.uuid4()),
                        "messages": messages,
                    }
                ),
                "params": {
                    "do_sample": {"type": "bool", "value": "true"},
                    "max_tokens_to_generate": {
                        "type": "int",
                        "value": str(self.model_config_dict.get("max_tokens")),
                    },
                    "process_prompt": {"type": "bool", "value": "true"},
                    "repetition_penalty": {
                        "type": "float",
                        "value": str(
                            self.model_config_dict.get("repetition_penalty")
                        ),
                    },
                    "return_token_count_only": {
                        "type": "bool",
                        "value": "false",
                    },
                    "select_expert": {
                        "type": "str",
                        "value": self.model_type.split('/')[1],
                    },
                    "stop_sequences": {
                        "type": "str",
                        "value": self.model_config_dict.get("stop_sequences"),
                    },
                    "temperature": {
                        "type": "float",
                        "value": str(
                            self.model_config_dict.get("temperature")
                        ),
                    },
                    "top_k": {
                        "type": "int",
                        "value": str(self.model_config_dict.get("top_k")),
                    },
                    "top_p": {
                        "type": "float",
                        "value": str(self.model_config_dict.get("top_p")),
                    },
                },
            }

            try:
                # Send the request and handle the response
                with httpx.Client() as client:
                    response = client.post(
                        self._url,  # type: ignore[arg-type]
                        headers=headers,
                        json=data,
                    )
                # Surface non-2xx statuses; without this call
                # `httpx.HTTPStatusError` would never be raised
                response.raise_for_status()
                raw_text = response.text
                # The body may contain several newline-separated JSON
                # objects; keep only the last one
                dicts = raw_text.split('}\n{')
                last_dict = '{' + dicts[-1]
                # Parse the reconstructed JSON object
                last_dict = json.loads(last_dict)
                return self._sambaverse_to_openai_response(last_dict)  # type: ignore[arg-type]
            except httpx.HTTPStatusError as exc:
                raise RuntimeError(
                    f"HTTP request failed: {exc.response.text}"
                )

    def _sambaverse_to_openai_response(
        self, samba_response: Dict[str, Any]
    ) -> ChatCompletion:
        r"""Converts a SambaVerse API response into an OpenAI-compatible
        response.

        Args:
            samba_response (Dict[str, Any]): A dictionary representing
                responses from the SambaVerse API.

        Returns:
            ChatCompletion: A `ChatCompletion` object constructed from the
                aggregated response data.
        """
        choices = [
            dict(
                index=0,
                message={
                    "role": 'assistant',
                    "content": samba_response['result']['responses'][0][
                        'completion'
                    ],
                },
                finish_reason=samba_response['result']['responses'][0][
                    'stop_reason'
                ],
            )
        ]

        obj = ChatCompletion.construct(
            id=None,
            choices=choices,
            created=int(time.time()),
            model=self.model_type,
            object="chat.completion",
            # The SambaVerse API only provides `total_tokens`
            usage=CompletionUsage(
                completion_tokens=0,
                prompt_tokens=0,
                total_tokens=int(
                    samba_response['result']['responses'][0][
                        'total_tokens_count'
                    ]
                ),
            ),
        )

        return obj

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends partial
        results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return self.model_config_dict.get('stream', False)