主要改动: - 更新依赖:使用 fastapi、uvicorn 替代 Flask - 数据库层:从 Flask-SQLAlchemy 迁移到纯 SQLAlchemy - 新增 database.py 管理数据库连接和会话 - 路由层:从 Blueprint 迁移到 APIRouter,添加 Pydantic 模型验证 - 应用层:使用 FastAPI 中间件替代 Flask 插件 - 启动方式:使用 uvicorn 替代 Flask 开发服务器 - 更新 Docker 配置以支持 FastAPI 优势: - 更高的性能和异步支持 - 自动生成 OpenAPI 文档 - 更好的类型安全和数据验证 - 所有 API 端点保持向后兼容 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
487 lines
15 KiB
Python
487 lines
15 KiB
Python
from datetime import datetime, timedelta
from typing import Optional, List

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from sqlalchemy.orm import Session
from pydantic import BaseModel
from sqlalchemy import func

from database import get_db
from models import Task, TimeRecord, User
from ai_service import ai_service


# Router collecting every endpoint defined in this module; presumably it is
# included into the FastAPI application elsewhere (mount prefix not visible here).
api: APIRouter = APIRouter()


# Pydantic models used to validate request bodies.
# (Documented with comments rather than docstrings: a class docstring would be
# copied into the generated OpenAPI schema as the model description.)


# Body of POST /auth/login.
class LoginRequest(BaseModel):
    username: str  # compared for equality against User.username
    password: str  # raw password as submitted; verified with User.check_password


# Body of POST /auth/register.
class RegisterRequest(BaseModel):
    username: str  # must not already exist (checked by the endpoint)
    password: str  # raw password; stored through User.set_password


# Body of POST /tasks. Optional fields fall back to the defaults below when
# omitted; note that an explicit JSON null is still accepted as None.
class TaskCreate(BaseModel):
    title: str                                # required
    description: Optional[str] = ''           # free-form text, empty by default
    status: Optional[str] = 'pending'         # initial task status


# Body of PUT /tasks/{task_id}. Every field is optional; a value of None means
# "leave this column unchanged" (update_task only applies non-None fields).
class TaskUpdate(BaseModel):
    title: Optional[str] = None
    description: Optional[str] = None
    status: Optional[str] = None


# Body of POST /timer/start and POST /timer/stop.
class TimerRequest(BaseModel):
    task_id: int  # id of the task whose timer is started/stopped


# Body of POST /timer/status/batch.
class TimerBatchRequest(BaseModel):
    task_ids: List[int]  # ids to report on; the endpoint rejects an empty list


# Authentication API

@api.post('/auth/login')
async def login(request: Request, data: LoginRequest, db: Session = Depends(get_db)):
    """Authenticate a user and record the identity in the session.

    On success ``user_id`` and ``username`` are written to ``request.session``
    and the serialized user is returned.

    Raises:
        HTTPException: 401 when the username is unknown or the password does
            not match (both cases produce the same message).
    """
    account = db.query(User).filter_by(username=data.username).first()

    if not account or not account.check_password(data.password):
        raise HTTPException(status_code=401, detail='用户名或密码错误')

    session = request.session
    session['user_id'] = account.id
    session['username'] = account.username

    return {'message': '登录成功', 'user': account.to_dict()}


@api.post('/auth/logout')
async def logout(request: Request):
    """Log the current user out by removing every key from the session."""
    session = request.session
    session.clear()
    response = {'message': '登出成功'}
    return response


@api.get('/auth/check')
async def check_auth(request: Request, db: Session = Depends(get_db)):
    """Report whether the current session belongs to an existing user.

    Returns:
        ``{'authenticated': True, 'user': <serialized user>}`` when the session
        holds a ``user_id`` that still resolves to a ``User`` row.

    Raises:
        HTTPException: 401 when the session has no user id, or the stored id
            no longer matches a user.
    """
    user_id = request.session.get('user_id')
    if user_id:
        # Session.get is the supported replacement for the legacy Query.get.
        user = db.get(User, user_id)
        if user:
            return {
                'authenticated': True,
                'user': user.to_dict()
            }

    raise HTTPException(status_code=401, detail='未登录')


@api.post('/auth/register', status_code=201)
async def register(data: RegisterRequest, db: Session = Depends(get_db)):
    """Create a user account (optional; meant for creating an initial user).

    Returns:
        A message plus the serialized new user, with HTTP 201.

    Raises:
        HTTPException: 400 when ``data.username`` is already taken.
    """
    existing = db.query(User).filter_by(username=data.username).first()
    if existing:
        raise HTTPException(status_code=400, detail='用户名已存在')

    new_user = User(username=data.username)
    new_user.set_password(data.password)

    db.add(new_user)
    db.commit()
    # Re-read the row's column values from the database before serializing.
    db.refresh(new_user)

    payload = {'message': '注册成功', 'user': new_user.to_dict()}
    return payload


# Task management API

@api.get('/tasks')
async def get_tasks(db: Session = Depends(get_db)):
    """Return every task, newest ``created_at`` first, serialized with ``to_dict(db)``."""
    newest_first = Task.created_at.desc()
    serialized = []
    for item in db.query(Task).order_by(newest_first).all():
        serialized.append(item.to_dict(db))
    return serialized


@api.post('/tasks', status_code=201)
async def create_task(data: TaskCreate, db: Session = Depends(get_db)):
    """Create a task and return its serialized form with HTTP 201.

    ``TaskCreate`` declares ``description`` and ``status`` as ``Optional`` with
    defaults ``''`` and ``'pending'``. A client can still send an explicit JSON
    ``null`` for either field, which Pydantic accepts as ``None``. Such nulls
    are normalized back to the declared defaults, so a new task never starts
    with a NULL status.
    """
    description = data.description if data.description is not None else ''
    status = data.status if data.status is not None else 'pending'

    task = Task(
        title=data.title,
        description=description,
        status=status
    )

    db.add(task)
    db.commit()
    db.refresh(task)

    return task.to_dict(db)


@api.put('/tasks/{task_id}')
async def update_task(task_id: int, data: TaskUpdate, db: Session = Depends(get_db)):
    """Partially update a task: only fields sent with a non-None value change.

    ``updated_at`` is always set to ``datetime.utcnow()``, even when no field
    was supplied.

    Raises:
        HTTPException: 404 when no task has ``task_id``.
    """
    task = db.query(Task).filter(Task.id == task_id).first()
    if not task:
        raise HTTPException(status_code=404, detail='任务不存在')

    for field_name in ('title', 'description', 'status'):
        new_value = getattr(data, field_name)
        if new_value is not None:
            setattr(task, field_name, new_value)

    task.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(task)

    return task.to_dict(db)


@api.delete('/tasks/{task_id}')
async def delete_task(task_id: int, db: Session = Depends(get_db)):
    """Delete the task identified by ``task_id``.

    Raises:
        HTTPException: 404 when the task does not exist.
    """
    doomed = db.query(Task).filter(Task.id == task_id).first()
    if not doomed:
        raise HTTPException(status_code=404, detail='任务不存在')

    db.delete(doomed)
    db.commit()
    return {'message': '任务删除成功'}


@api.post('/tasks/{task_id}/polish')
async def polish_task_description(task_id: int, db: Session = Depends(get_db)):
    """Rewrite a task's description with the AI service and store the result.

    The original ``description`` is left intact. The AI output is saved to
    ``polished_description``, and ``updated_at`` is bumped.

    Returns:
        ``{'original': <description>, 'polished': <AI text>}``.

    Raises:
        HTTPException: 404 if the task does not exist; 500 if
            ``ai_service.is_available()`` is false; 400 if the description
            is empty.
    """
    task = db.query(Task).filter(Task.id == task_id).first()
    if not task:
        raise HTTPException(status_code=404, detail='任务不存在')

    if not ai_service.is_available():
        raise HTTPException(status_code=500, detail='AI服务不可用,请检查API密钥配置')

    if not task.description:
        raise HTTPException(status_code=400, detail='任务描述为空,无法润色')

    # polish_description is a plain synchronous call to an external AI API
    # (hence the API-key error above). Calling it directly inside this
    # `async def` would stall the event loop, and so every other request,
    # for the whole round-trip. Run it in the threadpool instead.
    polished_description = await run_in_threadpool(
        ai_service.polish_description, task.description
    )

    task.polished_description = polished_description
    task.updated_at = datetime.utcnow()

    db.commit()

    return {
        'original': task.description,
        'polished': polished_description
    }


# Timer API

@api.post('/timer/start')
async def start_timer(data: TimerRequest, db: Session = Depends(get_db)):
    """Open a new time record for a task and set the task to 'in_progress'.

    A running timer is a ``TimeRecord`` whose ``end_time`` is NULL.

    Raises:
        HTTPException: 404 if the task does not exist; 400 if the task already
            has a running timer.
    """
    target_id = data.task_id

    task = db.query(Task).filter(Task.id == target_id).first()
    if not task:
        raise HTTPException(status_code=404, detail='任务不存在')

    running = db.query(TimeRecord).filter_by(task_id=target_id, end_time=None).first()
    if running:
        raise HTTPException(status_code=400, detail='该任务已在计时中')

    # NOTE(review): check-then-insert is not atomic; two concurrent requests
    # could both pass the check above and open two records.
    record = TimeRecord(task_id=target_id, start_time=datetime.utcnow())

    db.add(record)
    task.status = 'in_progress'
    db.commit()
    db.refresh(record)

    return record.to_dict()


@api.post('/timer/stop')
async def stop_timer(data: TimerRequest, db: Session = Depends(get_db)):
    """Close the running time record for a task and reset the task's status.

    Returns:
        The serialized, now-closed ``TimeRecord``.

    Raises:
        HTTPException: 400 when the task has no running timer
            (no record with a NULL ``end_time``).
    """
    time_record = db.query(TimeRecord).filter_by(
        task_id=data.task_id,
        end_time=None
    ).first()

    if not time_record:
        raise HTTPException(status_code=400, detail='没有找到进行中的计时')

    time_record.end_time = datetime.utcnow()
    # Model helper; presumably derives `duration` from start/end time
    # (defined in models, not visible here).
    time_record.calculate_duration()

    # Session.get is the supported replacement for the legacy Query.get.
    task = db.get(Task, data.task_id)
    if task:
        # Always returns to 'pending'; change here if other post-timer states
        # are required by business rules.
        task.status = 'pending'

    db.commit()
    db.refresh(time_record)

    return time_record.to_dict()


@api.get('/timer/status/{task_id}')
async def get_timer_status(task_id: int, db: Session = Depends(get_db)):
    """Tell whether ``task_id`` has a running timer and, if so, for how long.

    ``duration`` is whole seconds between ``start_time`` and
    ``datetime.utcnow()``.
    """
    running = db.query(TimeRecord).filter_by(task_id=task_id, end_time=None).first()
    if not running:
        return {'is_running': False}

    elapsed = datetime.utcnow() - running.start_time
    return {
        'is_running': True,
        'start_time': running.start_time.isoformat(),
        'duration': int(elapsed.total_seconds()),
    }


@api.post('/timer/status/batch')
async def get_timer_status_batch(data: TimerBatchRequest, db: Session = Depends(get_db)):
    """Return timer status for several tasks using a single query.

    Every requested id maps to ``{'is_running': False}`` unless it has a
    running record (NULL ``end_time``). In that case the value holds the
    ISO ``start_time`` and the whole seconds elapsed since then.

    Raises:
        HTTPException: 400 when ``task_ids`` is empty.
    """
    requested = data.task_ids
    if not requested:
        raise HTTPException(status_code=400, detail='任务ID列表不能为空')

    result = {int(tid): {'is_running': False} for tid in requested}

    running_records = (
        db.query(TimeRecord)
        .filter(
            TimeRecord.task_id.in_(requested),
            TimeRecord.end_time.is_(None),
        )
        .all()
    )

    now = datetime.utcnow()
    result.update(
        (
            rec.task_id,
            {
                'is_running': True,
                'start_time': rec.start_time.isoformat(),
                'duration': int((now - rec.start_time).total_seconds()),
            },
        )
        for rec in running_records
    )

    return result


# Reporting API

@api.get('/reports/daily')
async def get_daily_report(date: Optional[str] = None, db: Session = Depends(get_db)):
    """Summarize finished time records whose ``start_time`` falls on one day.

    Args:
        date: Day to report in ``YYYY-MM-DD`` form; defaults to the server's
            local current date.

    Returns:
        ``{'date', 'total_time', 'task_stats'}``. Each ``task_stats`` item holds
        the serialized task (None if the task row no longer exists), its
        summed ``duration`` (``total_duration``) and the individual records.

    Raises:
        HTTPException: 400 when ``date`` is not in ``YYYY-MM-DD`` format.
    """
    if date:
        try:
            target_date = datetime.strptime(date, '%Y-%m-%d').date()
        except ValueError:
            raise HTTPException(status_code=400, detail='日期格式错误,请使用YYYY-MM-DD格式') from None
    else:
        # NOTE(review): start_time is written with datetime.utcnow() (see
        # start_timer), but this default "today" is the server's local date;
        # the two disagree whenever the server is not running in UTC.
        target_date = datetime.now().date()

    start_datetime = datetime.combine(target_date, datetime.min.time())
    end_datetime = datetime.combine(target_date, datetime.max.time())

    # Only finished records (end_time set) are counted.
    records = db.query(TimeRecord).filter(
        TimeRecord.start_time >= start_datetime,
        TimeRecord.start_time <= end_datetime,
        TimeRecord.end_time.isnot(None)
    ).all()

    task_stats = {}
    total_time = 0

    for record in records:
        task_id = record.task_id
        if task_id not in task_stats:
            # Session.get replaces the legacy Query.get; repeat lookups of the
            # same id are answered from the session's identity map.
            task = db.get(Task, task_id)
            task_stats[task_id] = {
                'task': task.to_dict(db) if task else None,
                'total_duration': 0,
                'records': []
            }

        duration = record.duration or 0
        task_stats[task_id]['total_duration'] += duration
        task_stats[task_id]['records'].append(record.to_dict())
        total_time += duration

    return {
        'date': target_date.isoformat(),
        'total_time': total_time,
        'task_stats': list(task_stats.values())
    }


@api.get('/reports/summary')
async def get_summary_report(days: int = 7, db: Session = Depends(get_db)):
    """Per-day totals of finished time records over the last ``days`` days.

    The window covers ``days`` calendar days ending on the server's local
    current date. Records are bucketed by the date of their ``start_time``;
    only records with an ``end_time`` are counted.

    Args:
        days: Window length in days, including today. Must be >= 1.

    Returns:
        ``{'period': '<start> 至 <end>', 'daily_stats': [...]}``. Each daily
        entry holds ``date``, ``total_time`` and ``tasks``, a dict keyed by
        task id with ``task_title`` and ``total_time``.

    Raises:
        HTTPException: 400 when ``days`` is below 1, or so large that the
            start date cannot be represented by ``datetime.date``.
    """
    if days < 1:
        raise HTTPException(status_code=400, detail='天数必须为正整数')

    # NOTE(review): start_time is stored via datetime.utcnow(), while this
    # window is anchored on the server's local date.
    end_date = datetime.now().date()
    try:
        start_date = end_date - timedelta(days=days - 1)
    except OverflowError:
        raise HTTPException(status_code=400, detail='天数超出范围') from None

    start_datetime = datetime.combine(start_date, datetime.min.time())
    end_datetime = datetime.combine(end_date, datetime.max.time())

    records = db.query(TimeRecord).filter(
        TimeRecord.start_time >= start_datetime,
        TimeRecord.start_time <= end_datetime,
        TimeRecord.end_time.isnot(None)
    ).all()

    daily_stats = {}
    for record in records:
        date_key = record.start_time.date().isoformat()
        if date_key not in daily_stats:
            daily_stats[date_key] = {
                'date': date_key,
                'total_time': 0,
                'tasks': {}
            }
        day = daily_stats[date_key]

        task_id = record.task_id
        if task_id not in day['tasks']:
            # Session.get replaces the legacy Query.get; repeat lookups of the
            # same id are answered from the session's identity map.
            task = db.get(Task, task_id)
            day['tasks'][task_id] = {
                'task_title': task.title if task else f'任务{task_id}',
                'total_time': 0
            }

        duration = record.duration or 0
        day['total_time'] += duration
        day['tasks'][task_id]['total_time'] += duration

    return {
        'period': f'{start_date.isoformat()} 至 {end_date.isoformat()}',
        'daily_stats': list(daily_stats.values())
    }


# Time-segment history API

@api.get('/tasks/{task_id}/time-history')
async def get_task_time_history(
    task_id: int,
    days: int = 30,
    page: int = 1,
    per_page: int = 20,
    db: Session = Depends(get_db)
):
    """Paginated finished time records of one task, grouped by start date.

    The window covers ``days`` calendar days ending on the server's local
    current date. Pagination is applied to individual records, newest first,
    before they are grouped by day.

    Args:
        task_id: Task whose records are listed.
        days: Window length in days, including today. Must be >= 1.
        page: 1-based page number. Must be >= 1.
        per_page: Records per page. Must be >= 1.

    Returns:
        The serialized task, the period string, the per-day segments of this
        page, and a ``pagination`` block (``page``, ``pages``, ``per_page``,
        ``total``, ``has_next``, ``has_prev``).

    Raises:
        HTTPException: 400 for an out-of-range ``days``, ``page`` or
            ``per_page``; 404 when the task does not exist.
    """
    # Validate paging inputs first: per_page=0 would otherwise divide by zero
    # below, and page<1 would produce a negative OFFSET.
    if days < 1:
        raise HTTPException(status_code=400, detail='天数必须为正整数')
    if page < 1:
        raise HTTPException(status_code=400, detail='页码必须为正整数')
    if per_page < 1:
        raise HTTPException(status_code=400, detail='每页数量必须为正整数')

    task = db.query(Task).filter(Task.id == task_id).first()
    if not task:
        raise HTTPException(status_code=404, detail='任务不存在')

    end_date = datetime.now().date()
    try:
        start_date = end_date - timedelta(days=days - 1)
    except OverflowError:
        raise HTTPException(status_code=400, detail='天数超出范围') from None

    start_datetime = datetime.combine(start_date, datetime.min.time())
    end_datetime = datetime.combine(end_date, datetime.max.time())

    query = db.query(TimeRecord).filter(
        TimeRecord.task_id == task_id,
        TimeRecord.start_time >= start_datetime,
        TimeRecord.start_time <= end_datetime,
        TimeRecord.end_time.isnot(None)
    ).order_by(TimeRecord.start_time.desc())

    # Manual pagination.
    total = query.count()
    offset = (page - 1) * per_page
    records = query.offset(offset).limit(per_page).all()

    daily_segments = {}
    for record in records:
        date_key = record.start_time.date().isoformat()
        if date_key not in daily_segments:
            daily_segments[date_key] = {
                'date': date_key,
                'total_duration': 0,
                'segments': []
            }

        daily_segments[date_key]['total_duration'] += record.duration or 0
        daily_segments[date_key]['segments'].append(record.to_dict())

    # Ceiling division; safe because per_page >= 1 was enforced above.
    pages = (total + per_page - 1) // per_page

    return {
        'task': task.to_dict(db),
        'period': f'{start_date.isoformat()} 至 {end_date.isoformat()}',
        'daily_segments': list(daily_segments.values()),
        'pagination': {
            'page': page,
            'pages': pages,
            'per_page': per_page,
            'total': total,
            'has_next': page < pages,
            'has_prev': page > 1
        }
    }


@api.get('/time-history')
async def get_all_time_history(
    days: int = 7,
    task_id: Optional[int] = None,
    db: Session = Depends(get_db)
):
    """Finished time records of all tasks (or one), grouped by day then task.

    The window covers ``days`` calendar days ending on the server's local
    current date. Records are bucketed by the date of their ``start_time``.

    Args:
        days: Window length in days, including today. Must be >= 1.
        task_id: When given, only records of this task are included.

    Returns:
        ``{'period': '<start> 至 <end>', 'daily_tasks': [...]}`` with days
        ordered newest first. Each day holds ``total_time`` and a ``tasks``
        list of ``{'task', 'total_duration', 'segments'}``.

    Raises:
        HTTPException: 400 when ``days`` is below 1, or so large that the
            start date cannot be represented by ``datetime.date``.
    """
    if days < 1:
        raise HTTPException(status_code=400, detail='天数必须为正整数')

    end_date = datetime.now().date()
    try:
        start_date = end_date - timedelta(days=days - 1)
    except OverflowError:
        raise HTTPException(status_code=400, detail='天数超出范围') from None

    start_datetime = datetime.combine(start_date, datetime.min.time())
    end_datetime = datetime.combine(end_date, datetime.max.time())

    query = db.query(TimeRecord).filter(
        TimeRecord.start_time >= start_datetime,
        TimeRecord.start_time <= end_datetime,
        TimeRecord.end_time.isnot(None)
    )

    # `is not None` so that an explicit task_id of 0 still filters.
    if task_id is not None:
        query = query.filter(TimeRecord.task_id == task_id)

    records = query.order_by(TimeRecord.start_time.desc()).all()

    # date -> task id -> aggregated entry
    daily_tasks = {}
    for record in records:
        date_key = record.start_time.date().isoformat()
        task_id_key = record.task_id

        if date_key not in daily_tasks:
            daily_tasks[date_key] = {}
        day_tasks = daily_tasks[date_key]

        if task_id_key not in day_tasks:
            # Session.get replaces the legacy Query.get; repeat lookups of the
            # same id are answered from the session's identity map.
            task = db.get(Task, task_id_key)
            day_tasks[task_id_key] = {
                'task': task.to_dict(db) if task else None,
                'total_duration': 0,
                'segments': []
            }

        day_tasks[task_id_key]['total_duration'] += record.duration or 0
        day_tasks[task_id_key]['segments'].append(record.to_dict())

    result = [
        {
            'date': day,
            'total_time': sum(entry['total_duration'] for entry in tasks.values()),
            'tasks': list(tasks.values())
        }
        for day, tasks in daily_tasks.items()
    ]

    # Newest day first.
    result.sort(key=lambda item: item['date'], reverse=True)

    return {
        'period': f'{start_date.isoformat()} 至 {end_date.isoformat()}',
        'daily_tasks': result
    }