diff --git a/.env.example b/.env.example index 72f3df6..953d9f4 100644 --- a/.env.example +++ b/.env.example @@ -2,8 +2,24 @@ SECRET_KEY=your-secret-key-here-please-change-this FLASK_ENV=production -# OpenAI API配置(可选,用于AI润色功能) -OPENAI_API_KEY= +# AI服务配置(可选,用于AI润色功能) +# 支持多个AI服务商:OpenAI、Claude、Gemini等 + +# 方式1:使用统一配置(推荐) +AI_MODEL=gpt-3.5-turbo # 模型名称,例如: gpt-4, claude-3-sonnet-20240229, gemini-pro +AI_API_KEY= # API密钥 +# AI_API_BASE= # 可选:自定义API端点 +# AI_TEMPERATURE=0.7 # 可选:温度参数(0-1) +# AI_MAX_TOKENS=500 # 可选:最大token数 + +# 方式2:使用OpenAI配置(兼容旧版本) +# OPENAI_API_KEY= + +# 常用模型示例: +# OpenAI: gpt-3.5-turbo, gpt-4, gpt-4-turbo +# Claude: claude-3-opus-20240229, claude-3-sonnet-20240229, claude-3-haiku-20240307 +# Gemini: gemini-pro, gemini-1.5-pro +# 更多模型请参考: https://docs.litellm.ai/docs/providers # 默认用户配置 # Docker首次启动时会自动创建此用户 diff --git a/backend/ai_service.py b/backend/ai_service.py index 1c06ebe..0fd6546 100644 --- a/backend/ai_service.py +++ b/backend/ai_service.py @@ -1,4 +1,4 @@ -import openai +from litellm import completion import os from dotenv import load_dotenv @@ -6,54 +6,104 @@ load_dotenv() class AIService: def __init__(self): - # 设置OpenAI API密钥 - self.api_key = os.getenv('OPENAI_API_KEY') - # 延迟初始化客户端,避免在没有API密钥时出错 - self.client = None - + """ + 初始化AI服务 + 支持多个AI服务商: + - OpenAI (gpt-3.5-turbo, gpt-4等) + - Anthropic Claude (claude-3-opus, claude-3-sonnet等) + - Google Gemini (gemini-pro等) + - 其他LiteLLM支持的所有模型 + """ + # 获取AI配置 + self.model = os.getenv('AI_MODEL', 'gpt-3.5-turbo') + self.api_key = os.getenv('AI_API_KEY', os.getenv('OPENAI_API_KEY')) # 兼容旧配置 + self.api_base = os.getenv('AI_API_BASE') # 可选:自定义API端点 + self.temperature = float(os.getenv('AI_TEMPERATURE', '0.7')) + self.max_tokens = int(os.getenv('AI_MAX_TOKENS', '500')) + + # 设置环境变量供LiteLLM使用 + if self.api_key: + # 根据模型类型设置相应的环境变量 + if self.model.startswith('gpt-'): + os.environ['OPENAI_API_KEY'] = self.api_key + elif self.model.startswith('claude-'): + os.environ['ANTHROPIC_API_KEY'] = self.api_key + elif self.model.startswith('gemini-'): + os.environ['GEMINI_API_KEY'] = self.api_key + # LiteLLM会自动处理其他模型的API密钥 + def polish_description(self, description): """ 使用AI润色任务描述 + 支持多个AI服务商 """ if not description or not description.strip(): return description - + # 检查API密钥 - if not self.api_key or self.api_key == 'your_openai_api_key_here': - print("AI润色功能需要配置OpenAI API密钥") + if not self.api_key or self.api_key == 'your_api_key_here': + print(f"AI润色功能需要配置 AI_API_KEY (当前模型: {self.model})") return description - + try: - # 使用旧版本OpenAI API - openai.api_key = self.api_key - - response = openai.ChatCompletion.create( - model="gpt-3.5-turbo", - messages=[ + # 构建请求参数 + kwargs = { + "model": self.model, + "messages": [ { - "role": "system", + "role": "system", "content": "你是一个专业的工作任务描述润色助手。请将用户提供的工作任务描述润色得更加专业、清晰、具体。保持原意不变,但让描述更加规范和易于理解。" }, { - "role": "user", + "role": "user", "content": f"请润色以下工作任务描述:\n\n{description}" } ], - max_tokens=500, - temperature=0.7 - ) - + "max_tokens": self.max_tokens, + "temperature": self.temperature + } + + # 如果设置了自定义API端点 + if self.api_base: + kwargs["api_base"] = self.api_base + + # 使用LiteLLM统一接口调用 + response = completion(**kwargs) + polished = response.choices[0].message.content.strip() return polished - + except Exception as e: - print(f"AI润色失败: {e}") + print(f"AI润色失败 (模型: {self.model}): {e}") # 如果AI服务失败,返回原始描述 return description - + def is_available(self): """检查AI服务是否可用""" - return bool(self.api_key and self.api_key != 'your_openai_api_key_here') + return bool(self.api_key and self.api_key != 'your_api_key_here') + + def get_model_info(self): + """获取当前使用的模型信息""" + return { + "model": self.model, + "provider": self._get_provider(), + "available": self.is_available() + } + + def _get_provider(self): + """根据模型名称推断服务商""" + if self.model.startswith('gpt-'): + return 'OpenAI' + elif self.model.startswith('claude-'): + return 'Anthropic' + elif self.model.startswith('gemini-'): + return 'Google' + elif self.model.startswith('command'): + return 'Cohere' + elif self.model.startswith('llama'): + return 'Meta/Together' + else: + return 'Unknown' # 创建全局AI服务实例 ai_service = AIService() diff --git a/backend/requirements.txt b/backend/requirements.txt index f091b67..4e9cb94 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -2,6 +2,6 @@ fastapi==0.109.0 uvicorn[standard]==0.27.0 SQLAlchemy==2.0.21 python-dotenv==1.0.0 -openai==0.28.1 +litellm==1.50.3 werkzeug==3.0.1 itsdangerous