Source code for camel.models.reward.skywork_model
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Dict, List, Optional, Union
import torch
from camel.models.reward import BaseRewardModel
from camel.types import ModelType
class SkyworkRewardModel(BaseRewardModel):
    r"""Reward model backed by the ``transformers`` library; the model
    weights are downloaded from Hugging Face.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created.
        api_key (Optional[str], optional): Not used. (default: :obj:`None`)
        url (Optional[str], optional): Not used. (default: :obj:`None`)
        device_map (Optional[str], optional): The device map to use when
            loading the model. (default: :obj:`auto`)
        attn_implementation (Optional[str], optional): The attention
            implementation to use. (default: :obj:`flash_attention_2`)
        offload_folder (Optional[str], optional): The folder to offload
            model weights to. (default: :obj:`offload`)
    """
    def __init__(
        self,
        model_type: Union[ModelType, str],
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        device_map: Optional[str] = "auto",
        attn_implementation: Optional[str] = "flash_attention_2",
        offload_folder: Optional[str] = "offload",
    ) -> None:
        from transformers import (
            AutoModelForSequenceClassification,
            AutoTokenizer,
        )

        super().__init__(model_type, api_key, url)
        self._client = AutoModelForSequenceClassification.from_pretrained(
            model_type,
            torch_dtype=torch.bfloat16,
            device_map=device_map,
            attn_implementation=attn_implementation,
            offload_folder=offload_folder,
            num_labels=1,
        )
        self._tokenizer = AutoTokenizer.from_pretrained(model_type)

    def evaluate(self, messages: List[Dict[str, str]]) -> Dict[str, float]:
        r"""Evaluate the messages using the Skywork model.

        Args:
            messages (List[Dict[str, str]]): A list of messages, each a dict
                with ``role`` and ``content`` keys.

        Returns:
            Dict[str, float]: A dictionary mapping score types to scores.
        """
        inputs = self._tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            score = self._client(inputs).logits[0][0].item()
        return {"Score": score}

    def get_scores_types(self) -> List[str]:
        r"""Get the types of scores supported by the reward model.

        Returns:
            List[str]: A list of score types.
        """
        return ["Score"]
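

# ---------------------------------------------------------------------------
# Example usage (a minimal sketch, not part of the module above). The
# checkpoint id and the sample conversation are illustrative assumptions;
# attn_implementation="eager" is passed so the example does not depend on
# the optional flash-attn package.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    reward_model = SkyworkRewardModel(
        model_type="Skywork/Skywork-Reward-Llama-3.1-8B-v0.2",  # assumed HF repo id
        device_map="auto",
        attn_implementation="eager",
    )

    # A conversation is a list of {"role", "content"} dicts, as expected by
    # the tokenizer's chat template.
    conversation = [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 equals 4."},
    ]

    print(reward_model.get_scores_types())      # ["Score"]
    print(reward_model.evaluate(conversation))  # {"Score": <float reward>}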