diff --git a/Dockerfile b/Dockerfile index bd7fe89..11b4e08 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,8 +6,7 @@ WORKDIR /app # 设置环境变量 ENV PYTHONDONTWRITEBYTECODE=1 \ - PYTHONUNBUFFERED=1 \ - FLASK_APP=app.py + PYTHONUNBUFFERED=1 # 安装系统依赖 RUN apt-get update && apt-get install -y --no-install-recommends \ diff --git a/backend/app.py b/backend/app.py index 464e63d..5dd0475 100644 --- a/backend/app.py +++ b/backend/app.py @@ -1,44 +1,56 @@ -from flask import Flask, send_from_directory -from flask_cors import CORS -from models import db, Task, TimeRecord +from fastapi import FastAPI +from fastapi.staticfiles import StaticFiles +from fastapi.responses import FileResponse +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.sessions import SessionMiddleware +from database import init_db from routes import api import os def create_app(): - app = Flask(__name__) + app = FastAPI( + title="WorkList API", + description="任务管理和时间追踪API", + version="1.0.0" + ) - # 配置 - app.config['SECRET_KEY'] = 'your-secret-key-here-change-in-production' - app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///worklist.db' - app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False + # 配置 CORS + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 生产环境应该设置具体的域名 + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) - # Session配置 - app.config['SESSION_COOKIE_SAMESITE'] = 'Lax' - app.config['SESSION_COOKIE_HTTPONLY'] = True - app.config['PERMANENT_SESSION_LIFETIME'] = 86400 # 24小时 + # 配置 Session + app.add_middleware( + SessionMiddleware, + secret_key='your-secret-key-here-change-in-production', + max_age=86400, # 24小时 + same_site='lax', + https_only=False # 生产环境应设置为True + ) - # 初始化扩展 - db.init_app(app) - CORS(app, supports_credentials=True) # 允许跨域请求并支持凭证 + # 初始化数据库 + init_db() + + # 注册路由 + app.include_router(api, prefix='/api') - # 注册蓝图 - app.register_blueprint(api, url_prefix='/api') - - # 创建数据库表 - with app.app_context(): - db.create_all() - # 静态文件服务(用于前端) - @app.route('/') - def index(): - return send_from_directory('../frontend', 'index.html') - - @app.route('/') - def static_files(filename): - return send_from_directory('../frontend', filename) - + frontend_path = os.path.join(os.path.dirname(__file__), '../frontend') + if os.path.exists(frontend_path): + @app.get("/") + async def index(): + return FileResponse(os.path.join(frontend_path, 'index.html')) + + app.mount("/", StaticFiles(directory=frontend_path, html=True), name="static") + return app +app = create_app() + if __name__ == '__main__': - app = create_app() - app.run(debug=True, host='0.0.0.0', port=5000) + import uvicorn + uvicorn.run("app:app", host='0.0.0.0', port=5000, reload=True) diff --git a/backend/database.py b/backend/database.py new file mode 100644 index 0000000..c84cd25 --- /dev/null +++ b/backend/database.py @@ -0,0 +1,29 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, Session +from models import Base +import os + +# 数据库配置 +DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///worklist.db') + +# 创建引擎 +engine = create_engine( + DATABASE_URL, + connect_args={"check_same_thread": False} if DATABASE_URL.startswith('sqlite') else {}, + echo=False +) + +# 创建会话工厂 +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +def init_db(): + """初始化数据库""" + Base.metadata.create_all(bind=engine) + +def get_db(): + """获取数据库会话的依赖函数""" + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/backend/models.py b/backend/models.py index 14b9ebf..db2b928 100644 --- a/backend/models.py +++ b/backend/models.py @@ -1,18 +1,19 @@ -from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, func +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship from datetime import datetime -from sqlalchemy import func from werkzeug.security import generate_password_hash, check_password_hash -db = SQLAlchemy() +Base = declarative_base() -class User(db.Model): +class User(Base): """用户模型""" __tablename__ = 'users' - id = db.Column(db.Integer, primary_key=True) - username = db.Column(db.String(80), unique=True, nullable=False) - password_hash = db.Column(db.String(200), nullable=False) - created_at = db.Column(db.DateTime, default=datetime.utcnow) + id = Column(Integer, primary_key=True) + username = Column(String(80), unique=True, nullable=False) + password_hash = Column(String(200), nullable=False) + created_at = Column(DateTime, default=datetime.utcnow) def set_password(self, password): """设置密码(哈希加密)""" @@ -29,23 +30,23 @@ class User(db.Model): 'created_at': self.created_at.isoformat() if self.created_at else None } -class Task(db.Model): +class Task(Base): """任务模型""" __tablename__ = 'tasks' - - id = db.Column(db.Integer, primary_key=True) - title = db.Column(db.String(200), nullable=False) - description = db.Column(db.Text) - polished_description = db.Column(db.Text) # AI润色后的描述 - status = db.Column(db.String(20), default='pending') # pending, in_progress, completed - created_at = db.Column(db.DateTime, default=datetime.utcnow) - updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - + + id = Column(Integer, primary_key=True) + title = Column(String(200), nullable=False) + description = Column(Text) + polished_description = Column(Text) # AI润色后的描述 + status = Column(String(20), default='pending') # pending, in_progress, completed + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + # 关联关系 - time_records = db.relationship('TimeRecord', backref='task', lazy=True, cascade='all, delete-orphan') - - def to_dict(self): - return { + time_records = relationship('TimeRecord', back_populates='task', cascade='all, delete-orphan') + + def to_dict(self, db_session=None): + result = { 'id': self.id, 'title': self.title, 'description': self.description, @@ -53,27 +54,34 @@ class Task(db.Model): 'status': self.status, 'created_at': self.created_at.isoformat() if self.created_at else None, 'updated_at': self.updated_at.isoformat() if self.updated_at else None, - 'total_time': self.get_total_time() } - - def get_total_time(self): + if db_session: + result['total_time'] = self.get_total_time(db_session) + else: + result['total_time'] = 0 + return result + + def get_total_time(self, db_session): """获取任务总时长(秒)""" - total_seconds = db.session.query(func.sum(TimeRecord.duration)).filter( + total_seconds = db_session.query(func.sum(TimeRecord.duration)).filter( TimeRecord.task_id == self.id ).scalar() or 0 return int(total_seconds) -class TimeRecord(db.Model): +class TimeRecord(Base): """时间记录模型""" __tablename__ = 'time_records' - - id = db.Column(db.Integer, primary_key=True) - task_id = db.Column(db.Integer, db.ForeignKey('tasks.id'), nullable=False) - start_time = db.Column(db.DateTime, nullable=False) - end_time = db.Column(db.DateTime) - duration = db.Column(db.Integer) # 时长(秒) - created_at = db.Column(db.DateTime, default=datetime.utcnow) - + + id = Column(Integer, primary_key=True) + task_id = Column(Integer, ForeignKey('tasks.id'), nullable=False) + start_time = Column(DateTime, nullable=False) + end_time = Column(DateTime) + duration = Column(Integer) # 时长(秒) + created_at = Column(DateTime, default=datetime.utcnow) + + # 关联关系 + task = relationship('Task', back_populates='time_records') + def to_dict(self): return { 'id': self.id, @@ -83,7 +91,7 @@ class TimeRecord(db.Model): 'duration': self.duration, 'created_at': self.created_at.isoformat() if self.created_at else None } - + def calculate_duration(self): """计算时长""" if self.start_time and self.end_time: diff --git a/backend/requirements.txt b/backend/requirements.txt index 8700260..36d829e 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,6 +1,7 @@ -Flask==2.3.3 -Flask-CORS==4.0.0 +fastapi==0.109.0 +uvicorn[standard]==0.27.0 SQLAlchemy==2.0.21 -Flask-SQLAlchemy==3.0.5 python-dotenv==1.0.0 openai==0.28.1 +itsdangerous==2.1.2 +starlette-session==0.4.1 diff --git a/backend/routes.py b/backend/routes.py index e17f4bd..12f2715 100644 --- a/backend/routes.py +++ b/backend/routes.py @@ -1,251 +1,258 @@ -from flask import Blueprint, request, jsonify, session +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy.orm import Session from datetime import datetime, timedelta -from models import db, Task, TimeRecord, User +from database import get_db +from models import Task, TimeRecord, User from ai_service import ai_service -import json +from pydantic import BaseModel +from typing import Optional, List +from sqlalchemy import func -api = Blueprint('api', __name__) +api = APIRouter() + +# Pydantic 模型用于请求验证 +class LoginRequest(BaseModel): + username: str + password: str + +class RegisterRequest(BaseModel): + username: str + password: str + +class TaskCreate(BaseModel): + title: str + description: Optional[str] = '' + status: Optional[str] = 'pending' + +class TaskUpdate(BaseModel): + title: Optional[str] = None + description: Optional[str] = None + status: Optional[str] = None + +class TimerRequest(BaseModel): + task_id: int + +class TimerBatchRequest(BaseModel): + task_ids: List[int] # 认证API -@api.route('/auth/login', methods=['POST']) -def login(): +@api.post('/auth/login') +async def login(request: Request, data: LoginRequest, db: Session = Depends(get_db)): """用户登录""" - data = request.get_json() + user = db.query(User).filter_by(username=data.username).first() - if not data or 'username' not in data or 'password' not in data: - return jsonify({'error': '用户名和密码不能为空'}), 400 - - username = data['username'] - password = data['password'] - - user = User.query.filter_by(username=username).first() - - if user and user.check_password(password): + if user and user.check_password(data.password): # 登录成功,设置session - session['user_id'] = user.id - session['username'] = user.username - return jsonify({ + request.session['user_id'] = user.id + request.session['username'] = user.username + return { 'message': '登录成功', 'user': user.to_dict() - }) + } else: - return jsonify({'error': '用户名或密码错误'}), 401 + raise HTTPException(status_code=401, detail='用户名或密码错误') -@api.route('/auth/logout', methods=['POST']) -def logout(): +@api.post('/auth/logout') +async def logout(request: Request): """用户登出""" - session.clear() - return jsonify({'message': '登出成功'}) + request.session.clear() + return {'message': '登出成功'} -@api.route('/auth/check', methods=['GET']) -def check_auth(): +@api.get('/auth/check') +async def check_auth(request: Request, db: Session = Depends(get_db)): """检查登录状态""" - if 'user_id' in session: - user = User.query.get(session['user_id']) + user_id = request.session.get('user_id') + if user_id: + user = db.query(User).get(user_id) if user: - return jsonify({ + return { 'authenticated': True, 'user': user.to_dict() - }) - return jsonify({'authenticated': False}), 401 + } + raise HTTPException(status_code=401, detail='未登录') -@api.route('/auth/register', methods=['POST']) -def register(): +@api.post('/auth/register', status_code=201) +async def register(data: RegisterRequest, db: Session = Depends(get_db)): """用户注册(可选,用于创建初始用户)""" - data = request.get_json() - - if not data or 'username' not in data or 'password' not in data: - return jsonify({'error': '用户名和密码不能为空'}), 400 - - username = data['username'] - password = data['password'] - # 检查用户是否已存在 - if User.query.filter_by(username=username).first(): - return jsonify({'error': '用户名已存在'}), 400 + if db.query(User).filter_by(username=data.username).first(): + raise HTTPException(status_code=400, detail='用户名已存在') # 创建新用户 - user = User(username=username) - user.set_password(password) + user = User(username=data.username) + user.set_password(data.password) - db.session.add(user) - db.session.commit() + db.add(user) + db.commit() + db.refresh(user) - return jsonify({ + return { 'message': '注册成功', 'user': user.to_dict() - }), 201 + } # 任务管理API -@api.route('/tasks', methods=['GET']) -def get_tasks(): +@api.get('/tasks') +async def get_tasks(db: Session = Depends(get_db)): """获取所有任务""" - tasks = Task.query.order_by(Task.created_at.desc()).all() - return jsonify([task.to_dict() for task in tasks]) + tasks = db.query(Task).order_by(Task.created_at.desc()).all() + return [task.to_dict(db) for task in tasks] -@api.route('/tasks', methods=['POST']) -def create_task(): +@api.post('/tasks', status_code=201) +async def create_task(data: TaskCreate, db: Session = Depends(get_db)): """创建新任务""" - data = request.get_json() - - if not data or 'title' not in data: - return jsonify({'error': '任务标题不能为空'}), 400 - task = Task( - title=data['title'], - description=data.get('description', ''), - status=data.get('status', 'pending') + title=data.title, + description=data.description, + status=data.status ) - - db.session.add(task) - db.session.commit() - - return jsonify(task.to_dict()), 201 -@api.route('/tasks/', methods=['PUT']) -def update_task(task_id): + db.add(task) + db.commit() + db.refresh(task) + + return task.to_dict(db) + +@api.put('/tasks/{task_id}') +async def update_task(task_id: int, data: TaskUpdate, db: Session = Depends(get_db)): """更新任务""" - task = Task.query.get_or_404(task_id) - data = request.get_json() - - if 'title' in data: - task.title = data['title'] - if 'description' in data: - task.description = data['description'] - if 'status' in data: - task.status = data['status'] - + task = db.query(Task).filter(Task.id == task_id).first() + if not task: + raise HTTPException(status_code=404, detail='任务不存在') + + if data.title is not None: + task.title = data.title + if data.description is not None: + task.description = data.description + if data.status is not None: + task.status = data.status + task.updated_at = datetime.utcnow() - db.session.commit() - - return jsonify(task.to_dict()) + db.commit() + db.refresh(task) -@api.route('/tasks/', methods=['DELETE']) -def delete_task(task_id): + return task.to_dict(db) + +@api.delete('/tasks/{task_id}') +async def delete_task(task_id: int, db: Session = Depends(get_db)): """删除任务""" - task = Task.query.get_or_404(task_id) - db.session.delete(task) - db.session.commit() - - return jsonify({'message': '任务删除成功'}) + task = db.query(Task).filter(Task.id == task_id).first() + if not task: + raise HTTPException(status_code=404, detail='任务不存在') -@api.route('/tasks//polish', methods=['POST']) -def polish_task_description(task_id): + db.delete(task) + db.commit() + + return {'message': '任务删除成功'} + +@api.post('/tasks/{task_id}/polish') +async def polish_task_description(task_id: int, db: Session = Depends(get_db)): """AI润色任务描述""" - task = Task.query.get_or_404(task_id) - + task = db.query(Task).filter(Task.id == task_id).first() + if not task: + raise HTTPException(status_code=404, detail='任务不存在') + if not ai_service.is_available(): - return jsonify({'error': 'AI服务不可用,请检查API密钥配置'}), 500 - + raise HTTPException(status_code=500, detail='AI服务不可用,请检查API密钥配置') + if not task.description: - return jsonify({'error': '任务描述为空,无法润色'}), 400 - + raise HTTPException(status_code=400, detail='任务描述为空,无法润色') + polished_description = ai_service.polish_description(task.description) task.polished_description = polished_description task.updated_at = datetime.utcnow() - - db.session.commit() - - return jsonify({ + + db.commit() + + return { 'original': task.description, 'polished': polished_description - }) + } # 计时器API -@api.route('/timer/start', methods=['POST']) -def start_timer(): +@api.post('/timer/start') +async def start_timer(data: TimerRequest, db: Session = Depends(get_db)): """开始计时""" - data = request.get_json() - task_id = data.get('task_id') - - if not task_id: - return jsonify({'error': '任务ID不能为空'}), 400 - - task = Task.query.get_or_404(task_id) - + task = db.query(Task).filter(Task.id == data.task_id).first() + if not task: + raise HTTPException(status_code=404, detail='任务不存在') + # 检查是否已有进行中的计时 - active_record = TimeRecord.query.filter_by( - task_id=task_id, + active_record = db.query(TimeRecord).filter_by( + task_id=data.task_id, end_time=None ).first() - + if active_record: - return jsonify({'error': '该任务已在计时中'}), 400 - + raise HTTPException(status_code=400, detail='该任务已在计时中') + # 创建新的时间记录 time_record = TimeRecord( - task_id=task_id, + task_id=data.task_id, start_time=datetime.utcnow() ) - - db.session.add(time_record) - task.status = 'in_progress' - db.session.commit() - - return jsonify(time_record.to_dict()) -@api.route('/timer/stop', methods=['POST']) -def stop_timer(): + db.add(time_record) + task.status = 'in_progress' + db.commit() + db.refresh(time_record) + + return time_record.to_dict() + +@api.post('/timer/stop') +async def stop_timer(data: TimerRequest, db: Session = Depends(get_db)): """停止计时""" - data = request.get_json() - task_id = data.get('task_id') - - if not task_id: - return jsonify({'error': '任务ID不能为空'}), 400 - # 查找进行中的时间记录 - time_record = TimeRecord.query.filter_by( - task_id=task_id, + time_record = db.query(TimeRecord).filter_by( + task_id=data.task_id, end_time=None ).first() - + if not time_record: - return jsonify({'error': '没有找到进行中的计时'}), 400 - + raise HTTPException(status_code=400, detail='没有找到进行中的计时') + # 结束计时 time_record.end_time = datetime.utcnow() time_record.calculate_duration() - + # 更新任务状态 - task = Task.query.get(task_id) + task = db.query(Task).get(data.task_id) if task: task.status = 'pending' # 或者根据业务逻辑设置其他状态 - - db.session.commit() - - return jsonify(time_record.to_dict()) -@api.route('/timer/status/', methods=['GET']) -def get_timer_status(task_id): + db.commit() + db.refresh(time_record) + + return time_record.to_dict() + +@api.get('/timer/status/{task_id}') +async def get_timer_status(task_id: int, db: Session = Depends(get_db)): """获取任务计时状态""" - active_record = TimeRecord.query.filter_by( - task_id=task_id, + active_record = db.query(TimeRecord).filter_by( + task_id=task_id, end_time=None ).first() - + if active_record: - return jsonify({ + return { 'is_running': True, 'start_time': active_record.start_time.isoformat(), 'duration': int((datetime.utcnow() - active_record.start_time).total_seconds()) - }) + } else: - return jsonify({'is_running': False}) + return {'is_running': False} -@api.route('/timer/status/batch', methods=['POST']) -def get_timer_status_batch(): +@api.post('/timer/status/batch') +async def get_timer_status_batch(data: TimerBatchRequest, db: Session = Depends(get_db)): """批量获取任务计时状态""" - data = request.get_json() or {} - task_ids = data.get('task_ids', []) - - if not isinstance(task_ids, list) or not task_ids: - return jsonify({'error': '任务ID列表不能为空'}), 400 + if not data.task_ids: + raise HTTPException(status_code=400, detail='任务ID列表不能为空') # 初始化默认状态 - statuses = {int(task_id): {'is_running': False} for task_id in task_ids} + statuses = {int(task_id): {'is_running': False} for task_id in data.task_ids} - active_records = TimeRecord.query.filter( - TimeRecord.task_id.in_(task_ids), + active_records = db.query(TimeRecord).filter( + TimeRecord.task_id.in_(data.task_ids), TimeRecord.end_time.is_(None) ).all() @@ -257,72 +264,70 @@ def get_timer_status_batch(): 'duration': int((now - record.start_time).total_seconds()) } - return jsonify(statuses) + return statuses + # 统计报表API -@api.route('/reports/daily', methods=['GET']) -def get_daily_report(): +@api.get('/reports/daily') +async def get_daily_report(date: Optional[str] = None, db: Session = Depends(get_db)): """获取日报表""" - date_str = request.args.get('date') - if date_str: + if date: try: - target_date = datetime.strptime(date_str, '%Y-%m-%d').date() + target_date = datetime.strptime(date, '%Y-%m-%d').date() except ValueError: - return jsonify({'error': '日期格式错误,请使用YYYY-MM-DD格式'}), 400 + raise HTTPException(status_code=400, detail='日期格式错误,请使用YYYY-MM-DD格式') else: target_date = datetime.now().date() - + # 获取指定日期的所有时间记录 start_datetime = datetime.combine(target_date, datetime.min.time()) end_datetime = datetime.combine(target_date, datetime.max.time()) - - records = TimeRecord.query.filter( + + records = db.query(TimeRecord).filter( TimeRecord.start_time >= start_datetime, TimeRecord.start_time <= end_datetime, TimeRecord.end_time.isnot(None) ).all() - + # 按任务分组统计 task_stats = {} total_time = 0 - + for record in records: task_id = record.task_id if task_id not in task_stats: - task = Task.query.get(task_id) + task = db.query(Task).get(task_id) task_stats[task_id] = { - 'task': task.to_dict() if task else None, + 'task': task.to_dict(db) if task else None, 'total_duration': 0, 'records': [] } - + task_stats[task_id]['total_duration'] += record.duration or 0 task_stats[task_id]['records'].append(record.to_dict()) total_time += record.duration or 0 - - return jsonify({ + + return { 'date': target_date.isoformat(), 'total_time': total_time, 'task_stats': list(task_stats.values()) - }) + } -@api.route('/reports/summary', methods=['GET']) -def get_summary_report(): +@api.get('/reports/summary') +async def get_summary_report(days: int = 7, db: Session = Depends(get_db)): """获取汇总报表""" - days = int(request.args.get('days', 7)) # 默认最近7天 - end_date = datetime.now().date() start_date = end_date - timedelta(days=days-1) - + start_datetime = datetime.combine(start_date, datetime.min.time()) end_datetime = datetime.combine(end_date, datetime.max.time()) - + # 获取时间范围内的记录 - records = TimeRecord.query.filter( + records = db.query(TimeRecord).filter( TimeRecord.start_time >= start_datetime, TimeRecord.start_time <= end_datetime, TimeRecord.end_time.isnot(None) ).all() - + # 按日期分组 daily_stats = {} for record in records: @@ -333,60 +338,61 @@ def get_summary_report(): 'total_time': 0, 'tasks': {} } - + task_id = record.task_id if task_id not in daily_stats[date_key]['tasks']: - task = Task.query.get(task_id) + task = db.query(Task).get(task_id) daily_stats[date_key]['tasks'][task_id] = { 'task_title': task.title if task else f'任务{task_id}', 'total_time': 0 } - + duration = record.duration or 0 daily_stats[date_key]['total_time'] += duration daily_stats[date_key]['tasks'][task_id]['total_time'] += duration - - return jsonify({ + + return { 'period': f'{start_date.isoformat()} 至 {end_date.isoformat()}', 'daily_stats': list(daily_stats.values()) - }) + } # 时间段历史API -@api.route('/tasks//time-history', methods=['GET']) -def get_task_time_history(task_id): +@api.get('/tasks/{task_id}/time-history') +async def get_task_time_history( + task_id: int, + days: int = 30, + page: int = 1, + per_page: int = 20, + db: Session = Depends(get_db) +): """获取任务的时间段历史""" - task = Task.query.get_or_404(task_id) - - # 获取参数 - days = int(request.args.get('days', 30)) # 默认最近30天 - page = int(request.args.get('page', 1)) - per_page = int(request.args.get('per_page', 20)) - + task = db.query(Task).filter(Task.id == task_id).first() + if not task: + raise HTTPException(status_code=404, detail='任务不存在') + # 计算日期范围 end_date = datetime.now().date() start_date = end_date - timedelta(days=days-1) - + start_datetime = datetime.combine(start_date, datetime.min.time()) end_datetime = datetime.combine(end_date, datetime.max.time()) - + # 查询时间记录 - query = TimeRecord.query.filter( + query = db.query(TimeRecord).filter( TimeRecord.task_id == task_id, TimeRecord.start_time >= start_datetime, TimeRecord.start_time <= end_datetime, TimeRecord.end_time.isnot(None) ).order_by(TimeRecord.start_time.desc()) - - # 分页 - pagination = query.paginate( - page=page, - per_page=per_page, - error_out=False - ) - + + # 手动分页 + total = query.count() + offset = (page - 1) * per_page + records = query.offset(offset).limit(per_page).all() + # 按日期分组 daily_segments = {} - for record in pagination.items: + for record in records: date_key = record.start_time.date().isoformat() if date_key not in daily_segments: daily_segments[date_key] = { @@ -394,71 +400,73 @@ def get_task_time_history(task_id): 'total_duration': 0, 'segments': [] } - + daily_segments[date_key]['total_duration'] += record.duration or 0 daily_segments[date_key]['segments'].append(record.to_dict()) - - return jsonify({ - 'task': task.to_dict(), + + pages = (total + per_page - 1) // per_page + + return { + 'task': task.to_dict(db), 'period': f'{start_date.isoformat()} 至 {end_date.isoformat()}', 'daily_segments': list(daily_segments.values()), 'pagination': { - 'page': pagination.page, - 'pages': pagination.pages, - 'per_page': pagination.per_page, - 'total': pagination.total, - 'has_next': pagination.has_next, - 'has_prev': pagination.has_prev + 'page': page, + 'pages': pages, + 'per_page': per_page, + 'total': total, + 'has_next': page < pages, + 'has_prev': page > 1 } - }) + } -@api.route('/time-history', methods=['GET']) -def get_all_time_history(): +@api.get('/time-history') +async def get_all_time_history( + days: int = 7, + task_id: Optional[int] = None, + db: Session = Depends(get_db) +): """获取所有任务的时间段历史""" - # 获取参数 - days = int(request.args.get('days', 7)) # 默认最近7天 - task_id = request.args.get('task_id') # 可选的任务ID过滤 - # 计算日期范围 end_date = datetime.now().date() start_date = end_date - timedelta(days=days-1) - + start_datetime = datetime.combine(start_date, datetime.min.time()) end_datetime = datetime.combine(end_date, datetime.max.time()) - + # 构建查询 - query = TimeRecord.query.filter( + query = db.query(TimeRecord).filter( TimeRecord.start_time >= start_datetime, TimeRecord.start_time <= end_datetime, TimeRecord.end_time.isnot(None) ) - + if task_id: query = query.filter(TimeRecord.task_id == task_id) - + # 按开始时间排序 records = query.order_by(TimeRecord.start_time.desc()).all() - + # 按日期和任务分组 daily_tasks = {} for record in records: date_key = record.start_time.date().isoformat() - task_id = record.task_id - + task_id_key = record.task_id + if date_key not in daily_tasks: daily_tasks[date_key] = {} - - if task_id not in daily_tasks[date_key]: - task = Task.query.get(task_id) - daily_tasks[date_key][task_id] = { - 'task': task.to_dict() if task else None, + + if task_id_key not in daily_tasks[date_key]: + task = db.query(Task).get(task_id_key) + daily_tasks[date_key][task_id_key] = { + 'task': task.to_dict(db) if task else None, 'total_duration': 0, 'segments': [] } - - daily_tasks[date_key][task_id]['total_duration'] += record.duration or 0 - daily_tasks[date_key][task_id]['segments'].append(record.to_dict()) - + + daily_tasks[date_key][task_id_key]['total_duration'] += record.duration or 0 + daily_tasks[date_key][task_id_key]['segments'].append(record.to_dict()) + # 转换为列表格式 result = [] for date, tasks in daily_tasks.items(): @@ -468,11 +476,11 @@ def get_all_time_history(): 'tasks': list(tasks.values()) } result.append(day_data) - + # 按日期排序(最新的在前) result.sort(key=lambda x: x['date'], reverse=True) - - return jsonify({ + + return { 'period': f'{start_date.isoformat()} 至 {end_date.isoformat()}', 'daily_tasks': result - }) + } diff --git a/create_user.py b/create_user.py index b6050e9..6e6c0ba 100644 --- a/create_user.py +++ b/create_user.py @@ -5,16 +5,26 @@ 用于创建管理员账户 """ -from backend.app import create_app -from backend.models import db, User +import sys +import os + +# 将backend目录添加到Python路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'backend')) + +from database import SessionLocal, init_db +from models import User def create_initial_user(): """创建初始用户""" - app = create_app() + # 初始化数据库 + init_db() - with app.app_context(): + # 创建会话 + db = SessionLocal() + + try: # 检查是否已有用户 - existing_user = User.query.first() + existing_user = db.query(User).first() if existing_user: print(f"用户已存在: {existing_user.username}") @@ -27,12 +37,14 @@ def create_initial_user(): user = User(username=username) user.set_password(password) - db.session.add(user) - db.session.commit() + db.add(user) + db.commit() print(f"用户创建成功!") print(f"用户名: {username}") print(f"请妥善保管密码!") + finally: + db.close() if __name__ == '__main__': create_initial_user() diff --git a/docker-compose.yml b/docker-compose.yml index 3d50f15..f378b3e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,7 @@ version: '3.8' services: - # Flask应用服务 + # FastAPI应用服务 app: build: context: . @@ -11,8 +11,6 @@ services: ports: - "5001:5000" environment: - - FLASK_APP=app.py - - FLASK_ENV=production - PYTHONUNBUFFERED=1 - SECRET_KEY=${SECRET_KEY:-your-secret-key-here} - OPENAI_API_KEY=${OPENAI_API_KEY:-} diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh index b673c04..d976dc4 100644 --- a/docker-entrypoint.sh +++ b/docker-entrypoint.sh @@ -9,17 +9,19 @@ cd /app/backend # 等待数据库初始化 echo "初始化数据库..." python -c " -from app import create_app -from models import db, User +from database import init_db, SessionLocal +from models import User import os -app = create_app() -with app.app_context(): - # 创建所有表 - db.create_all() +# 初始化数据库 +init_db() +# 创建会话 +db = SessionLocal() + +try: # 检查是否已有用户 - existing_user = User.query.first() + existing_user = db.query(User).first() if not existing_user: # 从环境变量获取默认用户信息 @@ -29,16 +31,18 @@ with app.app_context(): # 创建默认用户 user = User(username=default_username) user.set_password(default_password) - db.session.add(user) - db.session.commit() + db.add(user) + db.commit() print(f'已创建默认用户: {default_username}') print(f'默认密码: {default_password}') print('请登录后立即修改密码!') else: print('用户已存在,跳过初始化') +finally: + db.close() " -echo "启动Flask应用..." -# 启动应用 -exec python app.py +echo "启动FastAPI应用..." +# 启动应用 (使用uvicorn) +exec uvicorn app:app --host 0.0.0.0 --port 5000 diff --git a/start.py b/start.py index f65c511..8799192 100644 --- a/start.py +++ b/start.py @@ -38,7 +38,7 @@ OPENAI_API_KEY=your_openai_api_key_here # 数据库配置 DATABASE_URL=sqlite:///worklist.db -# Flask配置 +# FastAPI配置 SECRET_KEY=your-secret-key-here """) print("✓ 环境变量文件已创建: backend/.env") @@ -49,8 +49,8 @@ def start_server(): print("正在启动服务器...") os.chdir("backend") try: - # 启动Flask应用 - subprocess.run([sys.executable, "app.py"]) + # 启动FastAPI应用 (使用uvicorn) + subprocess.run([sys.executable, "-m", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "5000", "--reload"]) except KeyboardInterrupt: print("\n服务器已停止") except Exception as e: