"""REST API routes: authentication, tasks, timers, reports and time history."""
from datetime import date, datetime, timedelta
from typing import Dict, Iterable, List, Optional, Tuple

from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session

from database import get_db
from models import Task, TimeRecord, User
from ai_service import ai_service

api = APIRouter()


# ---------------------------------------------------------------------------
# Pydantic models used for request-body validation
# ---------------------------------------------------------------------------

# Body of POST /auth/login.
class LoginRequest(BaseModel):
    username: str
    password: str


# Body of POST /auth/register.
class RegisterRequest(BaseModel):
    username: str
    password: str


# Body of POST /tasks; description and status fall back to '' and 'pending'.
class TaskCreate(BaseModel):
    title: str
    description: Optional[str] = ''
    status: Optional[str] = 'pending'


# Body of PUT /tasks/{task_id}; a field left as None is not modified.
class TaskUpdate(BaseModel):
    title: Optional[str] = None
    description: Optional[str] = None
    status: Optional[str] = None


# Body of POST /timer/start and /timer/stop.
class TimerRequest(BaseModel):
    task_id: int


# Body of POST /timer/status/batch.
class TimerBatchRequest(BaseModel):
    task_ids: List[int]


# ---------------------------------------------------------------------------
# Private helpers shared by the report / history endpoints
# ---------------------------------------------------------------------------

def _date_bounds(first_day: date, last_day: date) -> Tuple[datetime, datetime]:
    """Return naive datetimes from ``first_day`` 00:00:00 to ``last_day`` 23:59:59.999999."""
    return (
        datetime.combine(first_day, datetime.min.time()),
        datetime.combine(last_day, datetime.max.time()),
    )


def _recent_dates(days: int) -> Tuple[date, date]:
    """Return ``(start_date, end_date)`` spanning ``days`` calendar days ending today.

    "Today" is ``datetime.now().date()``, i.e. the server's local date.
    """
    # NOTE(review): TimeRecord.start_time is stamped with datetime.utcnow(), but
    # "today" here is the server-local date; the windows only line up when the
    # server runs in UTC — confirm the intended time zone.
    end_date = datetime.now().date()
    return end_date - timedelta(days=days - 1), end_date


def _tasks_by_id(db: Session, task_ids: Iterable[int]) -> Dict[int, Task]:
    """Load the given tasks in one query, keyed by id; ids with no task are omitted."""
    ids = list(set(task_ids))
    if not ids:
        return {}
    return {task.id: task for task in db.query(Task).filter(Task.id.in_(ids)).all()}


# ---------------------------------------------------------------------------
# Authentication API
# ---------------------------------------------------------------------------

@api.post('/auth/login')
async def login(request: Request, data: LoginRequest, db: Session = Depends(get_db)):
    """Log a user in and record their id and username in ``request.session``.

    Raises:
        HTTPException: 401 if the username is unknown or the password is wrong.
    """
    user = db.query(User).filter_by(username=data.username).first()
    if user and user.check_password(data.password):
        # Login succeeded: remember who is logged in for later requests.
        request.session['user_id'] = user.id
        request.session['username'] = user.username
        return {
            'message': '登录成功',
            'user': user.to_dict()
        }
    raise HTTPException(status_code=401, detail='用户名或密码错误')


@api.post('/auth/logout')
async def logout(request: Request):
    """Log the current user out by clearing the whole session."""
    request.session.clear()
    return {'message': '登出成功'}


@api.get('/auth/check')
async def check_auth(request: Request, db: Session = Depends(get_db)):
    """Report whether the session refers to an existing user.

    Raises:
        HTTPException: 401 if there is no ``user_id`` in the session or that user no longer exists.
    """
    user_id = request.session.get('user_id')
    if user_id:
        user = db.query(User).get(user_id)
        if user:
            return {
                'authenticated': True,
                'user': user.to_dict()
            }
    raise HTTPException(status_code=401, detail='未登录')


@api.post('/auth/register', status_code=201)
async def register(data: RegisterRequest, db: Session = Depends(get_db)):
    """Register a new user (optional; used to create the initial user).

    Raises:
        HTTPException: 400 if the username is already taken.
    """
    # Reject duplicate usernames up front.
    if db.query(User).filter_by(username=data.username).first():
        raise HTTPException(status_code=400, detail='用户名已存在')

    user = User(username=data.username)
    user.set_password(data.password)
    db.add(user)
    db.commit()
    db.refresh(user)

    return {
        'message': '注册成功',
        'user': user.to_dict()
    }


# ---------------------------------------------------------------------------
# Task management API
# NOTE(review): the task / timer / report endpoints below never read
# request.session; presumably authentication is enforced elsewhere
# (e.g. middleware) — confirm.
# ---------------------------------------------------------------------------

@api.get('/tasks')
async def get_tasks(db: Session = Depends(get_db)):
    """Return all tasks, newest first."""
    tasks = db.query(Task).order_by(Task.created_at.desc()).all()
    return [task.to_dict(db) for task in tasks]


@api.post('/tasks', status_code=201)
async def create_task(data: TaskCreate, db: Session = Depends(get_db)):
    """Create a task and return its serialized form."""
    task = Task(
        title=data.title,
        description=data.description,
        status=data.status
    )
    db.add(task)
    db.commit()
    db.refresh(task)
    return task.to_dict(db)


@api.put('/tasks/{task_id}')
async def update_task(task_id: int, data: TaskUpdate, db: Session = Depends(get_db)):
    """Update the fields of a task that were supplied (non-None) in the body.

    Raises:
        HTTPException: 404 if the task does not exist.
    """
    task = db.query(Task).filter(Task.id == task_id).first()
    if not task:
        raise HTTPException(status_code=404, detail='任务不存在')

    if data.title is not None:
        task.title = data.title
    if data.description is not None:
        task.description = data.description
    if data.status is not None:
        task.status = data.status

    task.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(task)
    return task.to_dict(db)


@api.delete('/tasks/{task_id}')
async def delete_task(task_id: int, db: Session = Depends(get_db)):
    """Delete a task.

    Raises:
        HTTPException: 404 if the task does not exist.
    """
    task = db.query(Task).filter(Task.id == task_id).first()
    if not task:
        raise HTTPException(status_code=404, detail='任务不存在')

    db.delete(task)
    db.commit()
    return {'message': '任务删除成功'}


@api.post('/tasks/{task_id}/polish')
def polish_task_description(task_id: int, db: Session = Depends(get_db)):
    """Polish a task's description with the AI service and store the result.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs it in its
    threadpool: ``ai_service.polish_description`` is called synchronously and
    would otherwise block the event loop while it runs.

    Returns:
        ``{'original': <description>, 'polished': <polished description>}``.

    Raises:
        HTTPException: 404 if the task does not exist, 500 if the AI service is
            unavailable, 400 if the task has no description.
    """
    task = db.query(Task).filter(Task.id == task_id).first()
    if not task:
        raise HTTPException(status_code=404, detail='任务不存在')

    if not ai_service.is_available():
        raise HTTPException(status_code=500, detail='AI服务不可用,请检查API密钥配置')

    if not task.description:
        raise HTTPException(status_code=400, detail='任务描述为空,无法润色')

    polished_description = ai_service.polish_description(task.description)
    task.polished_description = polished_description
    task.updated_at = datetime.utcnow()
    db.commit()

    return {
        'original': task.description,
        'polished': polished_description
    }


# ---------------------------------------------------------------------------
# Timer API
# ---------------------------------------------------------------------------

@api.post('/timer/start')
async def start_timer(data: TimerRequest, db: Session = Depends(get_db)):
    """Start timing a task and mark it 'in_progress'.

    Raises:
        HTTPException: 404 if the task does not exist, 400 if it already has an open record.
    """
    task = db.query(Task).filter(Task.id == data.task_id).first()
    if not task:
        raise HTTPException(status_code=404, detail='任务不存在')

    # A record with end_time=None is an open (still running) timer.
    active_record = db.query(TimeRecord).filter_by(
        task_id=data.task_id,
        end_time=None
    ).first()
    if active_record:
        raise HTTPException(status_code=400, detail='该任务已在计时中')

    time_record = TimeRecord(
        task_id=data.task_id,
        start_time=datetime.utcnow()
    )
    db.add(time_record)
    task.status = 'in_progress'
    db.commit()
    db.refresh(time_record)

    return time_record.to_dict()


@api.post('/timer/stop')
async def stop_timer(data: TimerRequest, db: Session = Depends(get_db)):
    """Close the task's open time record and set the task back to 'pending'.

    Raises:
        HTTPException: 400 if the task has no open time record.
    """
    time_record = db.query(TimeRecord).filter_by(
        task_id=data.task_id,
        end_time=None
    ).first()
    if not time_record:
        raise HTTPException(status_code=400, detail='没有找到进行中的计时')

    time_record.end_time = datetime.utcnow()
    time_record.calculate_duration()

    task = db.query(Task).get(data.task_id)
    if task:
        task.status = 'pending'  # or another status, depending on business rules

    db.commit()
    db.refresh(time_record)
    return time_record.to_dict()


@api.get('/timer/status/{task_id}')
async def get_timer_status(task_id: int, db: Session = Depends(get_db)):
    """Report whether a task has an open timer and, if so, its elapsed seconds."""
    active_record = db.query(TimeRecord).filter_by(
        task_id=task_id,
        end_time=None
    ).first()

    if active_record:
        return {
            'is_running': True,
            'start_time': active_record.start_time.isoformat(),
            'duration': int((datetime.utcnow() - active_record.start_time).total_seconds())
        }
    return {'is_running': False}


@api.post('/timer/status/batch')
async def get_timer_status_batch(data: TimerBatchRequest, db: Session = Depends(get_db)):
    """Return timer status for several tasks at once, keyed by task id.

    Raises:
        HTTPException: 400 if ``task_ids`` is empty.
    """
    if not data.task_ids:
        raise HTTPException(status_code=400, detail='任务ID列表不能为空')

    # Every requested task defaults to "not running"; open records override below.
    statuses = {int(task_id): {'is_running': False} for task_id in data.task_ids}

    active_records = db.query(TimeRecord).filter(
        TimeRecord.task_id.in_(data.task_ids),
        TimeRecord.end_time.is_(None)
    ).all()

    now = datetime.utcnow()
    for record in active_records:
        statuses[record.task_id] = {
            'is_running': True,
            'start_time': record.start_time.isoformat(),
            'duration': int((now - record.start_time).total_seconds())
        }

    return statuses


# ---------------------------------------------------------------------------
# Report API
# ---------------------------------------------------------------------------

@api.get('/reports/daily')
async def get_daily_report(date: Optional[str] = None, db: Session = Depends(get_db)):
    """Summarise finished time records whose start_time falls on one day.

    Args:
        date: ``YYYY-MM-DD``; defaults to the server's local date.

    Raises:
        HTTPException: 400 if ``date`` is not in ``YYYY-MM-DD`` format.
    """
    if date:
        try:
            target_date = datetime.strptime(date, '%Y-%m-%d').date()
        except ValueError:
            raise HTTPException(status_code=400, detail='日期格式错误,请使用YYYY-MM-DD格式') from None
    else:
        # NOTE(review): local date, while start_time values are stored in UTC — confirm.
        target_date = datetime.now().date()

    start_datetime, end_datetime = _date_bounds(target_date, target_date)
    records = db.query(TimeRecord).filter(
        TimeRecord.start_time >= start_datetime,
        TimeRecord.start_time <= end_datetime,
        TimeRecord.end_time.isnot(None)
    ).all()

    # Group by task, preserving first-seen order; tasks are loaded in one query.
    tasks = _tasks_by_id(db, (record.task_id for record in records))
    task_stats = {}
    total_time = 0
    for record in records:
        duration = record.duration or 0
        stats = task_stats.get(record.task_id)
        if stats is None:
            task = tasks.get(record.task_id)
            stats = task_stats[record.task_id] = {
                'task': task.to_dict(db) if task else None,
                'total_duration': 0,
                'records': []
            }
        stats['total_duration'] += duration
        stats['records'].append(record.to_dict())
        total_time += duration

    return {
        'date': target_date.isoformat(),
        'total_time': total_time,
        'task_stats': list(task_stats.values())
    }


@api.get('/reports/summary')
async def get_summary_report(days: int = Query(7, ge=1), db: Session = Depends(get_db)):
    """Summarise finished time per day and per task over the last ``days`` days.

    ``days`` must be >= 1 (FastAPI answers 422 otherwise); previously 0 or a
    negative value produced an inverted, empty window.
    """
    start_date, end_date = _recent_dates(days)
    start_datetime, end_datetime = _date_bounds(start_date, end_date)

    records = db.query(TimeRecord).filter(
        TimeRecord.start_time >= start_datetime,
        TimeRecord.start_time <= end_datetime,
        TimeRecord.end_time.isnot(None)
    ).all()

    tasks = _tasks_by_id(db, (record.task_id for record in records))
    daily_stats = {}
    for record in records:
        date_key = record.start_time.date().isoformat()
        if date_key not in daily_stats:
            daily_stats[date_key] = {
                'date': date_key,
                'total_time': 0,
                'tasks': {}
            }
        day = daily_stats[date_key]

        task_id = record.task_id
        if task_id not in day['tasks']:
            task = tasks.get(task_id)
            day['tasks'][task_id] = {
                'task_title': task.title if task else f'任务{task_id}',
                'total_time': 0
            }

        duration = record.duration or 0
        day['total_time'] += duration
        day['tasks'][task_id]['total_time'] += duration

    return {
        'period': f'{start_date.isoformat()} 至 {end_date.isoformat()}',
        'daily_stats': list(daily_stats.values())
    }


# ---------------------------------------------------------------------------
# Time-segment history API
# ---------------------------------------------------------------------------

@api.get('/tasks/{task_id}/time-history')
async def get_task_time_history(
    task_id: int,
    days: int = Query(30, ge=1),
    page: int = Query(1, ge=1),
    per_page: int = Query(20, ge=1),
    db: Session = Depends(get_db)
):
    """Return one task's finished time records (newest first), paginated and grouped by day.

    ``days``, ``page`` and ``per_page`` must be >= 1 (FastAPI answers 422
    otherwise); ``per_page=0`` used to crash with ZeroDivisionError and
    ``page<=0`` produced a negative SQL OFFSET.

    Raises:
        HTTPException: 404 if the task does not exist.
    """
    task = db.query(Task).filter(Task.id == task_id).first()
    if not task:
        raise HTTPException(status_code=404, detail='任务不存在')

    start_date, end_date = _recent_dates(days)
    start_datetime, end_datetime = _date_bounds(start_date, end_date)

    query = db.query(TimeRecord).filter(
        TimeRecord.task_id == task_id,
        TimeRecord.start_time >= start_datetime,
        TimeRecord.start_time <= end_datetime,
        TimeRecord.end_time.isnot(None)
    ).order_by(TimeRecord.start_time.desc())

    # Manual offset/limit pagination.
    total = query.count()
    offset = (page - 1) * per_page
    records = query.offset(offset).limit(per_page).all()

    # NOTE(review): grouping happens after pagination, so each day's
    # total_duration only covers records on the current page.
    daily_segments = {}
    for record in records:
        date_key = record.start_time.date().isoformat()
        if date_key not in daily_segments:
            daily_segments[date_key] = {
                'date': date_key,
                'total_duration': 0,
                'segments': []
            }
        daily_segments[date_key]['total_duration'] += record.duration or 0
        daily_segments[date_key]['segments'].append(record.to_dict())

    pages = (total + per_page - 1) // per_page  # ceiling division

    return {
        'task': task.to_dict(db),
        'period': f'{start_date.isoformat()} 至 {end_date.isoformat()}',
        'daily_segments': list(daily_segments.values()),
        'pagination': {
            'page': page,
            'pages': pages,
            'per_page': per_page,
            'total': total,
            'has_next': page < pages,
            'has_prev': page > 1
        }
    }


@api.get('/time-history')
async def get_all_time_history(
    days: int = Query(7, ge=1),
    task_id: Optional[int] = None,
    db: Session = Depends(get_db)
):
    """Return finished time records over the last ``days`` days, grouped by day then task.

    Args:
        days: window length in days (>= 1).
        task_id: when given, only records for this task are included.

    Days are returned newest first.
    """
    start_date, end_date = _recent_dates(days)
    start_datetime, end_datetime = _date_bounds(start_date, end_date)

    query = db.query(TimeRecord).filter(
        TimeRecord.start_time >= start_datetime,
        TimeRecord.start_time <= end_datetime,
        TimeRecord.end_time.isnot(None)
    )
    # `is not None` so that an explicit id of 0 still filters instead of being ignored.
    if task_id is not None:
        query = query.filter(TimeRecord.task_id == task_id)

    records = query.order_by(TimeRecord.start_time.desc()).all()

    # date_key -> task id -> {task, total_duration, segments}
    tasks = _tasks_by_id(db, (record.task_id for record in records))
    daily_tasks = {}
    for record in records:
        date_key = record.start_time.date().isoformat()
        day_tasks = daily_tasks.setdefault(date_key, {})
        entry = day_tasks.get(record.task_id)
        if entry is None:
            task = tasks.get(record.task_id)
            entry = day_tasks[record.task_id] = {
                'task': task.to_dict(db) if task else None,
                'total_duration': 0,
                'segments': []
            }
        entry['total_duration'] += record.duration or 0
        entry['segments'].append(record.to_dict())

    result = [
        {
            'date': day,
            'total_time': sum(entry['total_duration'] for entry in day_tasks.values()),
            'tasks': list(day_tasks.values())
        }
        for day, day_tasks in daily_tasks.items()
    ]
    # Newest day first.
    result.sort(key=lambda x: x['date'], reverse=True)

    return {
        'period': f'{start_date.isoformat()} 至 {end_date.isoformat()}',
        'daily_tasks': result
    }