918 lines
29 KiB
Python
918 lines
29 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
基于 mcp_docx.py 封装的 MCP 服务器。

暴露主要 MCP 工具:
- list_docx_images:列出 DOCX 中的图片信息
- edit_docx: 进行文本替换 / 关键字上色 / 报告日期与期数替换(与 HTTP POST /edit_docx 能力一致)

额外提供 HTTP 接口(仅在 http 模式下可用):
- POST /upload: 上传文件到服务器(multipart/form-data,字段名 file)
- GET /download/{filename} 或 GET /download?filename=...: 下载服务器上的文件
- POST /edit_docx: JSON 请求体,参数与 edit_docx 工具相同
- POST /log: 将请求体文本逐行加时间戳追加写入日志文件

当前推荐的传输方式:
- stdio(本地调试,默认)
- streamable-http(远程 HTTP,路径固定为 /mcp,推荐)

用法:
    # 本地 stdio 模式(默认)
    python mcp_docx_server.py --transport stdio

    # HTTP 远程模式(推荐,默认 0.0.0.0:8080,对外暴露 /mcp)
    python mcp_docx_server.py --transport http
    python mcp_docx_server.py --transport http --host 0.0.0.0 --port 8080

    # 客户端连接地址(http 模式):
    # MCP 端点: http://<host>:<port>/mcp

注意:底层仍然完全复用 mcp_docx.py 中的逻辑,只是通过 MCP SDK 对外提供。
|
||
"""
|
||
|
||
import argparse
import hashlib
import json
import os
import shutil
import sys
import tempfile
import time
import urllib.parse
import zipfile
from contextlib import contextmanager
from datetime import datetime, date, timedelta
from typing import Any, Dict, List, Optional

import requests
from lxml import etree
from mcp.server.fastmcp import FastMCP
from mcp.server.transport_security import TransportSecuritySettings

# HTTP 远程模式:添加文件上传下载路由
from starlette.background import BackgroundTask
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import FileResponse, JSONResponse
from starlette.responses import FileResponse, JSONResponse

from mcp_docx import (
    W,
    _normalize_newlines,
    get_images_info,
    process,
    _parse_span_replacement,
    paragraph_replace,
    unpack,
    pack,
)

if os.name == "nt":
    import msvcrt
else:
    import fcntl
|
||
|
||
# Hard-coded switch for the MCP transport's DNS-rebinding protection.
# True  -> protection off (any Host / Origin header accepted).
# False -> protection on, restricted to the allow-lists below.
_disable_dns_rebinding = True

if not _disable_dns_rebinding:
    # Protection on. Only these Host / Origin values are accepted; append gateway or
    # domain entries here when the server is reached through them.
    transport_security = TransportSecuritySettings(
        enable_dns_rebinding_protection=True,
        allowed_hosts=["localhost:*", "127.0.0.1:*", "192.168.1.13:*", "10.150.172.13:*"],
        allowed_origins=[
            "http://localhost:*",
            "http://127.0.0.1:*",
            "http://192.168.1.13:*",
            "http://10.150.172.13:*",
        ],
    )
else:
    # Protection off, following the python-sdk guidance. Suitable for local use, or when
    # an outer gateway already enforces host/origin checks.
    # https://github.com/modelcontextprotocol/python-sdk/issues/1798
    transport_security = TransportSecuritySettings(
        enable_dns_rebinding_protection=False,
    )
|
||
|
||
|
||
# FastMCP application. The @mcp.tool() functions and @mcp.custom_route handlers below
# register themselves on it.
mcp = FastMCP("docx-editor", transport_security=transport_security)

# Runtime server settings. The __main__ entry point fills these in; _build_output_url
# reads them to derive download links. They stay None until the entry point sets them.
_server_config: Dict[str, Any] = dict(host=None, port=None, transport=None)
|
||
|
||
|
||
# Accepted spellings of each report type, keyed by their strip()+lower() form.
_REPORT_TYPE_ALIASES: Dict[str, str] = {
    "日报": "daily",
    "日報": "daily",
    "daily": "daily",
    "d": "daily",
    "周报": "weekly",
    "週報": "weekly",
    "weekly": "weekly",
    "w": "weekly",
    "月报": "monthly",
    "月報": "monthly",
    "monthly": "monthly",
    "m": "monthly",
}


def _normalize_report_type(report_type: Optional[str]) -> Optional[str]:
    """Map a user-supplied report type to ``"daily"``, ``"weekly"`` or ``"monthly"``.

    The input is converted with ``str()``, stripped and lower-cased before lookup, so
    ``" Daily "``, ``"D"`` and ``"日报"`` all normalize to ``"daily"``.

    Returns:
        The canonical type, or None for empty input or an unrecognized value. Non-string
        values (e.g. a JSON list coming from the HTTP route) also return None instead of
        raising ``TypeError``.
    """
    if not report_type:
        return None
    # All alias keys are already stripped/lower-case, so looking up only the normalized
    # form is equivalent to the previous raw-then-normalized lookup. It also never hashes
    # the raw (possibly unhashable) value.
    return _REPORT_TYPE_ALIASES.get(str(report_type).strip().lower())
|
||
|
||
|
||
def _build_issue_text(norm_type: Optional[str], now: datetime) -> str:
    """Build the "date + issue number" text for a report generated at ``now``.

    - ``"daily"``: the date only, e.g. ``2024年5月3日``.
    - ``"weekly"``: the date plus the 1-based position of the current week's Monday
      among the Mondays of that Monday's own month. For example, 2024-05-01 falls in
      the week of Monday 2024-04-29, the 5th Monday of April, giving
      ``2024年5月1日(第5期)``.
    - anything else (monthly or unknown): the date plus ``(第1期)``.
    """
    today = now.date()
    stamp = f"{today.year}年{today.month}月{today.day}日"

    if norm_type == "daily":
        return stamp
    if norm_type != "weekly":
        return f"{stamp}(第1期)"

    week_start = today - timedelta(days=today.weekday())
    # A month's first Monday always falls on day 1..7. Its n-th Monday therefore
    # satisfies (day - 1) // 7 == n - 1.
    issue_no = (week_start.day - 1) // 7 + 1
    return f"{stamp}(第{issue_no}期)"
|
||
|
||
|
||
def _apply_report_date_logic_to_docx(
    docx_path: str,
    report_type: Optional[str],
    report_title_time: Optional[str],
) -> None:
    """Rewrite the cover-page date fragments of a DOCX in place.

    Only paragraphs before the first paragraph whose text contains "目录" are scanned.

    - The first ``YYYY年M月`` fragment that is NOT the start of a full ``YYYY年M月D日``
      date is replaced with ``report_title_time``. The full-date exclusion keeps the
      issue line below from being mangled when it precedes the title.
    - The first ``YYYY年M月D日第X期`` fragment, where the ``第X期`` part may be wrapped
      in ASCII or full-width parentheses, is replaced with
      ``_build_issue_text(normalized report_type, now)``.

    Args:
        docx_path: DOCX file to modify; silently ignored if it does not exist.
        report_type: report type accepted by ``_normalize_report_type``; when it does not
            normalize, the issue fragment is left alone.
        report_title_time: replacement for the month fragment; falsy means skip.

    The file is unpacked to a temp dir, and re-packed over ``docx_path`` only when at
    least one replacement was made.
    """
    import re

    norm_type = _normalize_report_type(report_type)
    # Nothing to replace, or nothing to replace it in.
    if not norm_type and not report_title_time:
        return
    if not os.path.exists(docx_path):
        return

    # Month-only date. The lookahead rejects the "YYYY年M月" prefix of "YYYY年M月D日".
    pattern_title = re.compile(r"(\d{4})年(\d{1,2})月(?!\d{1,2}日)")
    # Date + issue number. The original pattern used bare "(...)" groups, so it could
    # never match the parenthesized "(第X期)" that _build_issue_text produces.
    pattern_issue = re.compile(r"(\d{4})年(\d{1,2})月(\d{1,2})日[(（]?第(\d+)期[)）]?")

    now = datetime.now()

    with tempfile.TemporaryDirectory() as tmpdir:
        unpack(docx_path, tmpdir)
        doc_xml_path = os.path.join(tmpdir, "word", "document.xml")
        if not os.path.exists(doc_xml_path):
            return

        tree = etree.parse(doc_xml_path)
        root = tree.getroot()

        title_done = not report_title_time
        issue_done = not norm_type
        changed = False

        for p in root.iter(f"{{{W}}}p"):
            if title_done and issue_done:
                break

            # Paragraph text is split across runs; join all <w:t> nodes before matching.
            full = "".join(t.text or "" for t in p.iter(f"{{{W}}}t"))

            # Stop at the table of contents: only the cover page is rewritten.
            if "目录" in full:
                break

            para_repls = []

            if not title_done:
                m = pattern_title.search(full)
                if m:
                    para_repls.append((m.group(0), report_title_time))
                    title_done = True

            if not issue_done:
                m = pattern_issue.search(full)
                if m:
                    para_repls.append((m.group(0), _build_issue_text(norm_type, now)))
                    issue_done = True

            if para_repls:
                paragraph_replace(p, para_repls)
                changed = True

        if not changed:
            return

        tree.write(doc_xml_path, xml_declaration=True, encoding="UTF-8", standalone=True)
        # Re-pack over the original DOCX.
        pack(tmpdir, docx_path, docx_path)
|
||
|
||
|
||
def _is_url(path: str) -> bool:
    """Return True when ``path`` begins with ``http://`` or ``https://``.

    This is a plain, case-sensitive prefix test; the string is not otherwise parsed.
    """
    return path.startswith(("http://", "https://"))
|
||
|
||
|
||
def _download_to_temp(url: str, suffix: str = ".tmp") -> str:
    """Download ``url`` into a new temporary file and return the file's local path.

    The caller owns the returned file and must delete it when done. On any failure
    while writing, the partial temp file is removed before the exception propagates.

    Args:
        url: HTTP/HTTPS URL to fetch (30 s connect/read timeout).
        suffix: suffix for the temporary file name.

    Raises:
        requests.HTTPError: for a non-2xx response.
        requests.RequestException: for connection errors or timeouts.
    """
    # The response is used as a context manager so its streamed connection is always
    # released. Previously the response was never closed.
    with requests.get(url, stream=True, timeout=30) as resp:
        resp.raise_for_status()

        fd, tmp_path = tempfile.mkstemp(suffix=suffix)
        try:
            with os.fdopen(fd, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        except Exception:
            # Do not leave a partial download behind.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    return tmp_path
|
||
|
||
|
||
def _safe_filename(filename: Optional[str], default: str = "uploaded.docx") -> str:
    """Reduce a client-supplied name to a bare file name safe to join onto a directory.

    The value is percent-decoded. Both ``/`` and ``\\`` are treated as separators on
    every OS, NUL bytes are dropped, and only the final path component (stripped of
    surrounding whitespace) is kept.

    Returns:
        The sanitized name, or ``default`` when the input is empty or reduces to
        ``""``, ``"."`` or ``".."``. Those values would otherwise make
        ``os.path.join(dir, name)`` point at the directory itself or its parent.
    """
    if not filename:
        return default
    decoded = urllib.parse.unquote(str(filename))
    # os.path.basename on POSIX does not split on "\\"; normalize so "..\\x" cannot slip
    # through. NUL bytes would make later os calls fail.
    decoded = decoded.replace("\\", "/").replace("\x00", "")
    safe_name = os.path.basename(decoded).strip()
    if safe_name in ("", ".", ".."):
        return default
    return safe_name
|
||
|
||
|
||
def _filename_from_url(url: str, default: str = "uploaded.docx") -> str:
    """Derive a safe local file name from a URL.

    The first non-blank ``filename``, ``fileName`` or ``name`` query parameter (checked
    in that order) wins. Otherwise the last segment of the URL path is used. Either
    candidate is passed through ``_safe_filename`` with ``default`` as the fallback.
    """
    parts = urllib.parse.urlparse(url)
    params = urllib.parse.parse_qs(parts.query)

    candidate = next(
        (params[key][0] for key in ("filename", "fileName", "name") if params.get(key)),
        None,
    )
    if candidate is not None:
        return _safe_filename(candidate, default=default)
    return _safe_filename(os.path.basename(parts.path), default=default)
|
||
|
||
|
||
def _download_to_path(url: str, local_path: str) -> None:
    """Download ``url`` to ``local_path``, replacing the target atomically once complete.

    The body is streamed into a temp file created in the target's own directory, so the
    final ``os.replace`` never crosses filesystems. The temp file is removed if anything
    fails, which leaves the existing target untouched.

    Raises:
        requests.HTTPError: for a non-2xx response.
        requests.RequestException: for connection errors or timeouts (30 s connect/read).
        OSError: if the target directory cannot be created or written.
    """
    # Close the streamed response in all cases; it used to be leaked.
    with requests.get(url, stream=True, timeout=30) as resp:
        resp.raise_for_status()

        parent_dir = os.path.dirname(local_path) or "."
        os.makedirs(parent_dir, exist_ok=True)
        fd, tmp_path = tempfile.mkstemp(
            suffix=os.path.splitext(local_path)[1] or ".tmp",
            dir=parent_dir,
        )
        try:
            with os.fdopen(fd, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
            os.replace(tmp_path, local_path)
        except Exception:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
|
||
|
||
|
||
def _build_output_url(abs_output_path: str) -> Optional[str]:
    """Build the public download URL for ``abs_output_path`` (only its base name is used).

    Resolution order:

    1. ``MCP_OUTPUT_BASE_URL`` is set: the file name is attached to that base URL.
       - If its query already has a ``filename`` / ``fileName`` / ``name`` key, the
         first such key's value is replaced.
       - Otherwise, if its path ends with ``/download``, ``filename=<name>`` is appended
         to the query.
       - Otherwise the URL-quoted name is appended as a new path segment.
    2. http transport: ``http://<host>:<port>/download?filename=<name>``. A wildcard bind
       address (``0.0.0.0``, ``::`` or empty) is replaced by ``MCP_PUBLIC_HOST`` or
       ``localhost``, and IPv6 literals are wrapped in ``[...]``.
    3. Otherwise (stdio): None.
    """
    filename = os.path.basename(abs_output_path)
    encoded_filename = urllib.parse.quote(filename)

    def _append_filename(base_url: str) -> str:
        parsed = urllib.parse.urlparse(base_url)
        query = urllib.parse.parse_qsl(parsed.query, keep_blank_values=True)

        for index, (key, _) in enumerate(query):
            if key in ("filename", "fileName", "name"):
                query[index] = (key, filename)
                return urllib.parse.urlunparse(
                    parsed._replace(query=urllib.parse.urlencode(query))
                )

        if parsed.path.rstrip("/").endswith("/download"):
            query.append(("filename", filename))
            return urllib.parse.urlunparse(
                parsed._replace(query=urllib.parse.urlencode(query))
            )

        return base_url.rstrip("/") + "/" + encoded_filename

    # An explicit base URL takes precedence.
    base = os.getenv("MCP_OUTPUT_BASE_URL")
    if base:
        return _append_filename(base)

    if _server_config["transport"] != "http":
        return None

    host = _server_config["host"]
    port = _server_config["port"]
    if host in ("0.0.0.0", "::", "", None):
        # A wildcard bind address is not reachable by clients; prefer the configured
        # public address.
        host = os.getenv("MCP_PUBLIC_HOST") or "localhost"
    elif ":" in host and not host.startswith("["):
        # IPv6 literals must be bracketed inside a URL authority.
        host = f"[{host}]"

    return _append_filename(f"http://{host}:{port}/download")
|
||
|
||
|
||
def _get_upload_dir() -> str:
    """Return the absolute upload directory, creating it when missing.

    ``MCP_UPLOAD_DIR`` overrides the default ``./uploads``, which is resolved against
    the current working directory.
    """
    directory = os.path.abspath(os.getenv("MCP_UPLOAD_DIR", "./uploads"))
    os.makedirs(directory, exist_ok=True)
    return directory
|
||
|
||
|
||
def _get_tmp_upload_dir() -> str:
    """Return the absolute temporary-upload directory, creating it when missing.

    ``MCP_TMP_UPLOAD_DIR`` overrides the default ``./tmp``, which is resolved against
    the current working directory.
    """
    directory = os.path.abspath(os.getenv("MCP_TMP_UPLOAD_DIR", "./tmp"))
    os.makedirs(directory, exist_ok=True)
    return directory
|
||
|
||
|
||
def _get_lock_dir() -> str:
    """Return the ``.locks`` directory inside the upload directory, creating it if needed."""
    locks_path = os.path.join(_get_upload_dir(), ".locks")
    os.makedirs(locks_path, exist_ok=True)
    return locks_path
|
||
|
||
|
||
# Maximum UTF-8 byte length of the readable label in a lock-file name. The full name is
# label + "." + 64 hex digits + ".lock", which stays well below the usual 255-byte
# file-name limit.
_LOCK_LABEL_MAX_BYTES = 100


def _get_lock_path(target_path: str) -> str:
    """Return a stable lock-file path for ``target_path`` inside the lock directory.

    The name is ``<label>.<sha256(abs path)>.lock``. The digest keeps distinct targets
    apart. The label is the target's sanitized base name, purely for readability, and
    is truncated to ``_LOCK_LABEL_MAX_BYTES`` bytes. Without truncation a long CJK file
    name plus the 70-byte suffix could exceed the filesystem's name limit even though
    the target itself is valid.
    """
    abs_target = os.path.abspath(target_path)
    digest = hashlib.sha256(abs_target.encode("utf-8")).hexdigest()

    label = _safe_filename(os.path.basename(abs_target), default="file")
    # Truncate on bytes, then drop any multi-byte character cut in half.
    label = label.encode("utf-8")[:_LOCK_LABEL_MAX_BYTES].decode("utf-8", "ignore") or "file"

    return os.path.join(_get_lock_dir(), f"{label}.{digest}.lock")
|
||
|
||
|
||
def _acquire_lock(handle) -> None:
    """Block until an exclusive cross-process lock on ``handle``'s file is held.

    ``handle`` is a binary file object opened for reading and appending. If the file is
    empty, one byte is written first, because the Windows path locks byte 0 via
    ``msvcrt.locking``. POSIX uses ``fcntl.flock`` on the whole file. On Windows,
    ``LK_LOCK`` gives up with ``OSError`` after its own retries, so it is re-attempted
    every 50 ms until it succeeds.
    """
    handle.seek(0, os.SEEK_END)
    if not handle.tell():
        handle.write(b"0")
        handle.flush()
    handle.seek(0)

    if os.name != "nt":
        fcntl.flock(handle.fileno(), fcntl.LOCK_EX)
        return

    while True:
        try:
            msvcrt.locking(handle.fileno(), msvcrt.LK_LOCK, 1)
        except OSError:
            time.sleep(0.05)
        else:
            return
|
||
|
||
|
||
def _release_lock(handle) -> None:
    """Release the exclusive lock taken by ``_acquire_lock`` on the same handle.

    The handle is rewound to offset 0 first, because the Windows byte-range unlock
    applies from the current position.
    """
    handle.seek(0)
    fileno = handle.fileno()
    if os.name != "nt":
        fcntl.flock(fileno, fcntl.LOCK_UN)
    else:
        msvcrt.locking(fileno, msvcrt.LK_UNLCK, 1)
|
||
|
||
|
||
@contextmanager
def _file_lock(target_path: str):
    """Hold an exclusive cross-process lock for ``target_path`` for the ``with`` body.

    The lock lives on a sidecar file (see ``_get_lock_path``), not on the target, so the
    target can be atomically replaced while locked. Sidecar files are never deleted.
    Not re-entrant: on POSIX, a nested acquire for the same target opens a second handle
    and blocks, because flock locks belong to the open file.
    """
    with open(_get_lock_path(target_path), "a+b") as lock_handle:
        _acquire_lock(lock_handle)
        try:
            yield
        finally:
            _release_lock(lock_handle)
|
||
|
||
|
||
def _write_bytes_atomic(file_path: str, content: bytes) -> None:
    """Write ``content`` to ``file_path`` so readers see either the old or the new file.

    Missing parent directories are created. The bytes go to a temp file in the same
    directory, which then replaces the target via ``os.replace``. On failure the temp
    file is removed and the error re-raised.
    """
    target_dir = os.path.dirname(file_path) or "."
    os.makedirs(target_dir, exist_ok=True)

    extension = os.path.splitext(file_path)[1]
    fd, staging_path = tempfile.mkstemp(suffix=extension or ".tmp", dir=target_dir)
    try:
        with os.fdopen(fd, "wb") as out:
            out.write(content)
        os.replace(staging_path, file_path)
    except Exception:
        try:
            os.remove(staging_path)
        except OSError:
            pass
        raise
|
||
|
||
|
||
def _snapshot_file(file_path: str) -> str:
    """Copy ``file_path`` into a new temp file and return the copy's path.

    The download route uses this so the response can be streamed after the per-file
    lock is released. The copy keeps the source's extension (``.tmp`` if it has none).
    The caller deletes it. On failure the partial copy is removed and the error
    re-raised.
    """
    extension = os.path.splitext(file_path)[1]
    fd, copy_path = tempfile.mkstemp(suffix=extension or ".tmp")
    try:
        # Wrap the fd first so it is closed even if opening the source fails.
        with os.fdopen(fd, "wb") as target, open(file_path, "rb") as source:
            shutil.copyfileobj(source, target)
    except Exception:
        try:
            os.remove(copy_path)
        except OSError:
            pass
        raise
    return copy_path
|
||
|
||
|
||
def _cleanup_temp_file(file_path: str) -> None:
    """Best-effort delete of ``file_path``; any ``OSError`` (e.g. already gone) is ignored."""
    try:
        os.unlink(file_path)
    except OSError:
        return
|
||
|
||
|
||
def _compute_file_etag(file_path: str) -> str:
    """Return the hex SHA-256 of the file's contents, used as its version etag.

    The file is read in 64 KiB chunks, so large documents are never fully loaded.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        while chunk := stream.read(65536):
            digest.update(chunk)
    return digest.hexdigest()
|
||
|
||
|
||
# In-memory version registry: absolute file path -> etag recorded after that file's most
# recent upload or edit. It is a plain module-level dict, so it lives only in this
# process. Callers read and write it while holding that file's _file_lock; no extra
# thread lock is used.
_file_etag_registry: Dict[str, str] = dict()


def _register_etag(abs_path: str, etag: str) -> None:
    """Record ``etag`` as the latest known version of ``abs_path``."""
    _file_etag_registry.update({abs_path: etag})
|
||
|
||
|
||
def _check_etag(abs_path: str) -> None:
    """Raise if ``abs_path`` no longer matches the version last recorded for it.

    Call this while holding the file's lock. Files with no recorded etag pass silently.
    A mismatch means the file was changed by something other than a registered upload
    or edit while this operation was waiting.

    Raises:
        ValueError: on a version conflict; the message shows both etag prefixes.
    """
    expected = _file_etag_registry.get(abs_path)
    if expected:
        actual = _compute_file_etag(abs_path)
        if actual != expected:
            raise ValueError(
                "文件已被其他操作修改(版本冲突),请确认最新上传后重试。"
                f"已知: {expected[:12]}…,当前: {actual[:12]}…"
            )
|
||
|
||
|
||
def _validate_docx_file(file_path: str) -> None:
    """Check that ``file_path`` is an existing, non-empty ZIP container (a DOCX candidate).

    Only the ZIP signature is checked; the Word parts inside are not inspected.

    Raises:
        FileNotFoundError: if the path is missing or is not a regular file.
        ValueError: if the file is empty or not a ZIP archive.
    """
    display_name = os.path.basename(file_path)

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"输入 DOCX 文件不存在: {file_path}")
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"输入路径不是文件: {file_path}")

    if not os.path.getsize(file_path) > 0:
        raise ValueError(f"输入 DOCX 文件为空: {display_name}")
    if not zipfile.is_zipfile(file_path):
        raise ValueError(f"输入文件不是合法的 DOCX/ZIP: {display_name}")
|
||
|
||
|
||
def _resolve_edit_target_path(input_docx_path: str, upload_dir: str) -> str:
    """Map an edit request's input to the local file that the edit will (re)write.

    - HTTP/HTTPS URL: ``<upload_dir>/<name from _filename_from_url>``.
    - Relative path or bare name: ``<upload_dir>/<_safe_filename(input)>``.
    - Absolute path: returned unchanged.

    NOTE(review): absolute paths are not confined to ``upload_dir``. The edit flow
    overwrites the resolved file in place, so any client of the HTTP ``/edit_docx``
    route can target any ZIP-format file the server process can write. Consider
    restricting this to ``upload_dir`` when the HTTP transport is exposed.
    """
    if _is_url(input_docx_path):
        return os.path.join(upload_dir, _filename_from_url(input_docx_path))
    if not os.path.isabs(input_docx_path):
        return os.path.join(upload_dir, _safe_filename(input_docx_path))
    return input_docx_path
|
||
|
||
|
||
@mcp.tool()
async def list_docx_images(docx_url: str) -> List[Dict[str, Any]]:
    """List every image embedded in a DOCX file.

    Args:
        docx_url: HTTP/HTTPS URL of the DOCX file.

    Returns:
        One dict per image, with:
        - index: position of the image in the document (1-based)
        - media_file: resource path inside the DOCX package
        - ext: image file extension
        - docpr_name: Word's internal name for the picture
        - width_cm / height_cm: approximate size in centimetres (may be None)
    """

    def _collect() -> List[Dict[str, Any]]:
        # The downloaded copy is only needed while it is inspected; it used to be
        # leaked on every call.
        local_copy = _download_to_temp(docx_url, suffix=".docx")
        try:
            return get_images_info(local_copy)
        finally:
            _cleanup_temp_file(local_copy)

    # Downloading and parsing are blocking; keep them off the server's event loop.
    return await run_in_threadpool(_collect)
|
||
|
||
|
||
def _build_replacement_rules(
    replacements: Optional[List[Dict[str, Any]]],
) -> tuple:
    """Split replacement items into plain ``(old, new)`` pairs and span colour keywords.

    Items whose ``old`` is empty are skipped. A missing/None ``new`` means deletion
    (``""``). Newlines in both fields are normalized, and ``new`` is parsed for
    ``<span color=...>`` markup by ``_parse_span_replacement``.

    Raises:
        ValueError: if an item is not a dict (e.g. a bare string in a JSON array).
    """
    rep_pairs = []
    color_keywords = []
    for item in replacements or []:
        if not isinstance(item, dict):
            raise ValueError(f"replacements 中的每一项都必须是对象(包含 old/new 字段): {item!r}")
        old = item.get("old")
        if not old:
            continue
        new_raw = item.get("new")
        new_raw = "" if new_raw is None else _normalize_newlines(new_raw)
        new_plain, spans = _parse_span_replacement(new_raw)
        rep_pairs.append((_normalize_newlines(old), new_plain))
        color_keywords.extend(spans)
    return rep_pairs, color_keywords


def _edit_docx_core(
    input_docx_path: str,
    replacements: Optional[List[Dict[str, Any]]],
    report_type: Optional[str],
    report_title_time: Optional[str],
) -> Dict[str, Any]:
    """Edit a DOCX in place; shared by the ``edit_docx`` tool and ``POST /edit_docx``.

    Args:
        input_docx_path: HTTP(S) URL, absolute path, or bare file name inside the upload
            directory (see ``_resolve_edit_target_path``). For a URL, the downloaded
            content is the edit source and the result is stored under the upload dir.
        replacements: ``[{"old": ..., "new": ...}, ...]``; ``new`` may contain
            ``<span color=...>`` keyword colouring.
        report_type / report_title_time: optional cover-page date rewriting, applied
            best-effort (failures are logged to stderr and ignored).

    Returns:
        ``{"output_path": absolute path, "output_url": URL or None, "etag": SHA-256}``.

    Raises:
        ValueError: missing input path, malformed replacement item, invalid DOCX, or a
            version conflict for a local input.
        FileNotFoundError: local input does not exist.
        requests.RequestException: URL download failed.
    """
    # stdout is the JSON-RPC channel in stdio mode, so diagnostics go to stderr.
    print(
        f"edit_docx: input_docx_path: {input_docx_path}, replacements: {replacements}",
        file=sys.stderr,
    )

    if not isinstance(input_docx_path, str) or not input_docx_path.strip():
        raise ValueError("缺少 input_docx_path 参数")

    rep_pairs, color_keywords = _build_replacement_rules(replacements)

    upload_dir = _get_upload_dir()
    local_input = _resolve_edit_target_path(input_docx_path, upload_dir)
    abs_target = os.path.abspath(local_input)

    downloaded = None
    output_docx = None
    try:
        if _is_url(input_docx_path):
            # Download BEFORE taking the lock. The URL is typically this server's own
            # /download link for the same file (see _build_output_url), and that route
            # needs this lock; downloading while holding it deadlocks until timeout.
            downloaded = _download_to_temp(input_docx_path, suffix=".docx")
        source_path = downloaded or local_input

        with _file_lock(local_input):
            _validate_docx_file(source_path)

            if downloaded is None:
                # Version check only applies when the local file itself is the source.
                # A URL download is caller-supplied content; comparing it with the last
                # edited output would reject every repeat edit from the same URL.
                _check_etag(abs_target)

            parent_dir = os.path.dirname(local_input) or "."
            fd, output_docx = tempfile.mkstemp(suffix=".docx", dir=parent_dir)
            os.close(fd)

            process(
                input_docx=source_path,
                output_docx=output_docx,
                replacements=rep_pairs,
                color_keywords=color_keywords,
            )

            if report_type or report_title_time:
                try:
                    _apply_report_date_logic_to_docx(
                        output_docx,
                        report_type=report_type,
                        report_title_time=report_title_time,
                    )
                except Exception as e:
                    # Best-effort: the text edits above are still delivered.
                    print(f"apply report date logic failed: {e}", file=sys.stderr)

            os.replace(output_docx, local_input)
            output_docx = None  # now owned by local_input; nothing to clean up
            new_etag = _compute_file_etag(abs_target)
            _register_etag(abs_target, new_etag)
    finally:
        if output_docx is not None:
            _cleanup_temp_file(output_docx)
        if downloaded is not None:
            _cleanup_temp_file(downloaded)

    return {
        "output_path": abs_target,
        "output_url": _build_output_url(abs_target),
        "etag": new_etag,
    }
|
||
|
||
|
||
@mcp.tool()
def edit_docx(
    input_docx_path: str,
    replacements: Optional[List[Dict[str, Any]]] = None,
    report_type: Optional[str] = None,
    report_title_time: Optional[str] = None,
) -> Dict[str, Any]:
    """Replace text / colour keywords in a DOCX (the file is updated in place).

    Supports:
    - plain text replacement
    - keyword colouring with <span color="red">关键字</span> markup in "new"
    - automatic report date / issue-number rewriting (only before the "目录" page)

    Args:
        input_docx_path: uploaded file name (or an HTTP/HTTPS URL, or an absolute path)
        replacements: list of rules, e.g.
            [{"old": "原文", "new": "<span color='red'>新文</span>"}]
        report_type: 日报/周报/月报 or daily/weekly/monthly; omit to skip issue rewriting
        report_title_time: text that replaces the first month-only 「YYYY年M月」
            fragment on the cover page; omit to skip

    Returns:
        - success: {"success": true, "output_path": ..., "output_url": ..., "etag": ...}
        - failure: {"success": false, "message": "..."}
    """
    try:
        out = _edit_docx_core(
            input_docx_path,
            replacements,
            report_type,
            report_title_time,
        )
        return {"success": True, **out}
    except Exception as e:
        # Tool boundary: report the error to the client instead of failing the call.
        return {"success": False, "message": str(e)}
|
||
|
||
|
||
@mcp.custom_route("/edit_docx", methods=["POST"])
async def edit_docx_handler(request: Request):
    """HTTP counterpart of the ``edit_docx`` tool.

    Expects a JSON object body with ``input_docx_path`` (required), ``replacements``
    (array, optional), ``report_type`` and ``report_title_time`` (optional).

    Returns:
        200 ``{"success": True, "output_path", "output_url", "etag"}`` on success;
        400 ``{"success": False, "message"}`` for a malformed body;
        500 ``{"success": False, "message"}`` when the edit fails.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        return JSONResponse(
            {
                "success": False,
                "message": (
                    "请求体必须是合法的 JSON 对象。"
                    "请使用 Content-Type: application/json,并发送非空的 JSON body"
                    "(空 body、form-data 或 urlencoded 会导致此错误)。"
                ),
            },
            status_code=400,
        )

    if not isinstance(data, dict):
        return JSONResponse(
            {"success": False, "message": "请求体必须是 JSON 对象,而不是数组或标量。"},
            status_code=400,
        )

    input_docx_path = data.get("input_docx_path")
    replacements = data.get("replacements")
    report_type = data.get("report_type")
    report_title_time = data.get("report_title_time")

    if not isinstance(input_docx_path, str) or not input_docx_path.strip():
        return JSONResponse(
            {"success": False, "message": "缺少 input_docx_path 参数"},
            status_code=400,
        )
    if replacements is not None and not isinstance(replacements, list):
        return JSONResponse(
            {"success": False, "message": "replacements 必须是数组"},
            status_code=400,
        )

    try:
        # The edit is blocking (locks, downloads, zip processing). Running it in the
        # threadpool keeps the event loop free, e.g. to serve a /download request that
        # the edit itself triggered.
        result = await run_in_threadpool(
            _edit_docx_core,
            input_docx_path,
            replacements,
            report_type,
            report_title_time,
        )
        return JSONResponse({"success": True, **result})
    except Exception as e:
        return JSONResponse({"success": False, "message": str(e)}, status_code=500)
|
||
|
||
def _get_log_path() -> str:
    """Return the absolute log-file path, creating its parent directory when missing.

    ``MCP_LOG_FILE`` (a full file path) overrides the default ``./logs/mcp.log``, which
    is resolved against the current working directory.
    """
    path = os.path.abspath(os.getenv("MCP_LOG_FILE", "./logs/mcp.log"))
    os.makedirs(os.path.dirname(path), exist_ok=True)
    return path
|
||
|
||
@mcp.custom_route("/log", methods=["POST"])
async def append_log(request: Request):
    """Append the raw request body to the log file, timestamping every line.

    The body is the message itself (not JSON). It is decoded as UTF-8, with undecodable
    bytes replaced by U+FFFD. Each of its lines is written as
    ``[YYYY-mm-dd HH:MM:SS] <line>``, so a multi-line message cannot produce lines
    without a timestamp (or lines that forge one).

    Returns:
        JSON ``{"success": True, "log_path": ..., "message": ...}`` on success;
        400 when the body is empty;
        500 ``{"success": False, "log_path": ..., "message": ...}`` when writing fails
        (``log_path`` is None if even the path could not be resolved).
    """
    try:
        body = await request.body()
        if not body:
            return JSONResponse(
                {
                    "success": False,
                    "message": "未提供消息内容",
                },
                status_code=400,
            )

        log_path = _get_log_path()
        ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        text = body.decode("utf-8", errors="replace")
        # splitlines() also splits on \r and other Unicode line boundaries.
        lines = text.splitlines() or [""]
        entry = "".join(f"[{ts}] {line}\n" for line in lines)
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(entry)

        return JSONResponse(
            {
                "success": True,
                "log_path": log_path,
                "message": "已写入日志",
            }
        )
    except Exception as e:
        # Resolving the path can itself fail (e.g. unwritable directory); don't let
        # the error response raise.
        try:
            log_path = _get_log_path()
        except OSError:
            log_path = None
        return JSONResponse(
            {
                "success": False,
                "log_path": log_path,
                "message": f"写入日志失败: {str(e)}",
            },
            status_code=500,
        )
|
||
@mcp.custom_route("/upload", methods=["POST"])
async def upload_handler(request: Request):
    """Store a multipart-uploaded file in the upload directory.

    Expects a ``file`` form field. The stored name is the client's file name passed
    through ``_safe_filename``. An existing file of that name is replaced atomically
    under its per-file lock, and its etag is re-registered so subsequent edits treat
    the upload as the current version.

    Returns:
        JSON ``{"success": True, "filename", "file_path", "file_url", "size", "message"}``;
        400 when no file or an empty file is sent; 500 on any other failure.
    """
    try:
        form = await request.form()
        try:
            file = form.get("file")
            # A plain text field named "file" arrives as str (no .filename / .read()).
            if not file or isinstance(file, str):
                return JSONResponse({
                    "success": False,
                    "message": "未提供文件"
                }, status_code=400)

            upload_dir = _get_upload_dir()
            orig_filename = file.filename or "uploaded.docx"
            # Path-traversal protection; otherwise keep the client's name for edit_docx.
            filename = _safe_filename(orig_filename)
            file_path = os.path.join(upload_dir, filename)
            content = await file.read()
        finally:
            # Release the spooled temp files backing the multipart upload.
            await form.close()

        if not content:
            return JSONResponse({
                "success": False,
                "message": f"上传文件为空: {filename}"
            }, status_code=400)

        abs_file_path = os.path.abspath(file_path)

        def _store() -> None:
            with _file_lock(file_path):
                _write_bytes_atomic(file_path, content)
                _register_etag(abs_file_path, _compute_file_etag(abs_file_path))

        # Waiting for the lock and writing/hashing the file are blocking operations.
        await run_in_threadpool(_store)

        return JSONResponse({
            "success": True,
            "filename": filename,  # name to pass to edit_docx
            "file_path": file_path,  # absolute path (upload_dir is absolute)
            "file_url": _build_output_url(file_path),
            "size": len(content),
            "message": f"文件上传成功: {filename}"
        })
    except Exception as e:
        return JSONResponse({
            "success": False,
            "message": f"文件上传失败: {str(e)}"
        }, status_code=500)
|
||
@mcp.custom_route("/download", methods=["GET"])
@mcp.custom_route("/download/{filename}", methods=["GET"])
async def download_handler(request: Request):
    """Serve a file from the upload directory.

    The file name comes from the path (``/download/{filename}``) or from the
    ``filename`` / ``fileName`` / ``name`` query parameter. An optional
    ``download_filename`` / ``new_filename`` / ``rename_filename`` query parameter sets
    the name offered to the client.

    The file is copied under its per-file lock, so the response never sees a
    half-written edit. The copy is streamed after the lock is released and deleted once
    the response has been sent.

    Returns:
        The file as ``application/octet-stream``; 400 when no name is given or the
        target is not a regular file; 404 when it does not exist; 500 on other errors.
    """
    try:
        query = request.query_params
        filename = (
            request.path_params.get("filename")
            or query.get("filename")
            or query.get("fileName")
            or query.get("name")
        )
        upload_dir = _get_upload_dir()
        download_filename = (
            query.get("download_filename")
            or query.get("new_filename")
            or query.get("rename_filename")
        )

        if not filename:
            return JSONResponse({
                "success": False,
                "message": "缺少 filename 参数"
            }, status_code=400)

        # Path-traversal protection.
        filename = _safe_filename(filename)
        file_path = os.path.join(upload_dir, filename)

        def _snapshot_under_lock():
            with _file_lock(file_path):
                if not os.path.exists(file_path):
                    return 404, None
                if not os.path.isfile(file_path):
                    return 400, None
                return 200, _snapshot_file(file_path)

        # Waiting for the lock and copying the file are blocking operations.
        status, snapshot_path = await run_in_threadpool(_snapshot_under_lock)

        if status == 404:
            return JSONResponse({
                "success": False,
                "message": f"文件不存在: {filename}"
            }, status_code=404)
        if status == 400:
            return JSONResponse({
                "success": False,
                "message": f"不是文件: {filename}"
            }, status_code=400)

        return FileResponse(
            snapshot_path,
            filename=_safe_filename(download_filename, default=filename),
            media_type="application/octet-stream",
            background=BackgroundTask(_cleanup_temp_file, snapshot_path),
        )
    except Exception as e:
        return JSONResponse({
            "success": False,
            "message": f"文件下载失败: {str(e)}"
        }, status_code=500)
|
||
|
||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="DOCX MCP 服务器")
    parser.add_argument(
        "--transport",
        choices=["stdio", "http"],
        default="stdio",
        help="传输方式:stdio(本地)或 http(远程 HTTP /streamable-http)",
    )
    parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="HTTP 模式监听地址(默认 0.0.0.0)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8080,
        help="HTTP 模式监听端口(默认 8080)",
    )
    args = parser.parse_args()

    if args.transport == "http":
        # Recorded for _build_output_url, which derives download links from them.
        _server_config["host"] = args.host
        _server_config["port"] = args.port
        _server_config["transport"] = "http"

        mcp.settings.host = args.host
        mcp.settings.port = args.port

        # The custom routes (/upload, /download, /edit_docx, /log) were already registered
        # on the FastMCP Starlette app by their @mcp.custom_route decorators.
        print(f"🚀 MCP HTTP 服务器启动中 → http://{args.host}:{args.port}/mcp")
        mcp.run(transport="streamable-http")
    else:
        _server_config["transport"] = "stdio"
        # In stdio mode stdout carries the JSON-RPC stream; any non-protocol text there
        # breaks the client, so status output goes to stderr.
        print("🚀 MCP stdio 模式启动中(本地使用)", file=sys.stderr)
        mcp.run(transport="stdio")
|