121 lines
4.7 KiB
Python
121 lines
4.7 KiB
Python
from litellm import completion
|
||
import os
|
||
from dotenv import load_dotenv
|
||
|
||
load_dotenv()
|
||
|
||
class AIService:
|
||
def __init__(self):
|
||
"""
|
||
初始化AI服务
|
||
支持多个AI服务商:
|
||
- OpenAI (gpt-3.5-turbo, gpt-4等)
|
||
- Anthropic Claude (claude-3-opus, claude-3-sonnet等)
|
||
- Google Gemini (gemini-pro等)
|
||
- 自定义端点(兼容OpenAI格式的API)
|
||
- 其他LiteLLM支持的所有模型
|
||
"""
|
||
# 获取AI配置
|
||
self.model = os.getenv('AI_MODEL', 'gpt-3.5-turbo')
|
||
self.api_key = os.getenv('AI_API_KEY', os.getenv('OPENAI_API_KEY')) # 兼容旧配置
|
||
self.api_base = os.getenv('AI_API_BASE') # 可选:自定义API端点
|
||
self.custom_llm_provider = os.getenv('AI_CUSTOM_PROVIDER') # 可选:自定义端点的API格式(如 'openai')
|
||
self.temperature = float(os.getenv('AI_TEMPERATURE', '0.7'))
|
||
self.max_tokens = int(os.getenv('AI_MAX_TOKENS', '500'))
|
||
|
||
# 设置环境变量供LiteLLM使用
|
||
if self.api_key:
|
||
# 如果使用自定义端点,优先使用OPENAI_API_KEY(兼容OpenAI格式的端点)
|
||
if self.api_base and self.custom_llm_provider == 'openai':
|
||
os.environ['OPENAI_API_KEY'] = self.api_key
|
||
# 根据模型类型设置相应的环境变量
|
||
elif self.model.startswith('gpt-'):
|
||
os.environ['OPENAI_API_KEY'] = self.api_key
|
||
elif self.model.startswith('claude-'):
|
||
os.environ['ANTHROPIC_API_KEY'] = self.api_key
|
||
elif self.model.startswith('gemini-'):
|
||
# 如果没有设置自定义provider,使用默认的Gemini API Key
|
||
if not self.custom_llm_provider:
|
||
os.environ['GEMINI_API_KEY'] = self.api_key
|
||
# LiteLLM会自动处理其他模型的API密钥
|
||
|
||
def polish_description(self, description):
|
||
"""
|
||
使用AI润色任务描述
|
||
支持多个AI服务商
|
||
"""
|
||
if not description or not description.strip():
|
||
return description
|
||
|
||
# 检查API密钥
|
||
if not self.api_key or self.api_key == 'your_api_key_here':
|
||
print(f"AI润色功能需要配置 AI_API_KEY (当前模型: {self.model})")
|
||
return description
|
||
|
||
try:
|
||
# 构建请求参数
|
||
kwargs = {
|
||
"model": self.model,
|
||
"messages": [
|
||
{
|
||
"role": "system",
|
||
"content": "你是一个专业的工作任务描述润色助手。请将用户提供的工作任务描述润色得更加专业、清晰、具体。保持原意不变,但让描述更加规范和易于理解。"
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": f"请润色以下工作任务描述:\n\n{description}"
|
||
}
|
||
],
|
||
"max_tokens": self.max_tokens,
|
||
"temperature": self.temperature
|
||
}
|
||
|
||
# 如果设置了自定义API端点
|
||
if self.api_base:
|
||
kwargs["api_base"] = self.api_base
|
||
|
||
# 如果设置了自定义provider(用于兼容OpenAI格式的自定义端点)
|
||
if self.custom_llm_provider:
|
||
kwargs["custom_llm_provider"] = self.custom_llm_provider
|
||
|
||
# 使用LiteLLM统一接口调用
|
||
response = completion(**kwargs)
|
||
|
||
polished = response.choices[0].message.content.strip()
|
||
return polished
|
||
|
||
except Exception as e:
|
||
print(f"AI润色失败 (模型: {self.model}): {e}")
|
||
# 如果AI服务失败,返回原始描述
|
||
return description
|
||
|
||
def is_available(self):
|
||
"""检查AI服务是否可用"""
|
||
return bool(self.api_key and self.api_key != 'your_api_key_here')
|
||
|
||
def get_model_info(self):
|
||
"""获取当前使用的模型信息"""
|
||
return {
|
||
"model": self.model,
|
||
"provider": self._get_provider(),
|
||
"available": self.is_available()
|
||
}
|
||
|
||
def _get_provider(self):
|
||
"""根据模型名称推断服务商"""
|
||
if self.model.startswith('gpt-'):
|
||
return 'OpenAI'
|
||
elif self.model.startswith('claude-'):
|
||
return 'Anthropic'
|
||
elif self.model.startswith('gemini-'):
|
||
return 'Google'
|
||
elif self.model.startswith('command'):
|
||
return 'Cohere'
|
||
elif self.model.startswith('llama'):
|
||
return 'Meta/Together'
|
||
else:
|
||
return 'Unknown'
|
||
|
||
# 创建全局AI服务实例
|
||
ai_service = AIService()
|