diff --git a/mcp_docx.py b/mcp_docx.py index 8b09e20..76d63bc 100644 --- a/mcp_docx.py +++ b/mcp_docx.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -docx_editor.py — 保留原格式替换文本 + 修改字体颜色 + 替换图片 +docx_editor.py — 保留原格式替换文本 + 修改字体颜色 用法: # 列出文档中所有图片 @@ -10,25 +10,14 @@ docx_editor.py — 保留原格式替换文本 + 修改字体颜色 + 替换图 python3 docx_editor.py input.docx output.docx \ --replace "原文" "新文" \ --color "关键词" "FF0000" - - # 图片替换(按文档中出现的顺序,从1开始) - python3 docx_editor.py input.docx output.docx \ - --image 1 new_chart.png \ - --image 2 new_photo.jpg - - # 同时替换文字和图片 - python3 docx_editor.py input.docx output.docx \ - --replace "旧标题" "新标题" \ - --image 1 new_image.png \ - --color "重点" "FF0000" """ import argparse +import copy import os import tempfile import zipfile from lxml import etree -from PIL import Image import re W = 'http://schemas.openxmlformats.org/wordprocessingml/2006/main' @@ -37,12 +26,6 @@ A = 'http://schemas.openxmlformats.org/drawingml/2006/main' R = 'http://schemas.openxmlformats.org/officeDocument/2006/relationships' REL_TYPE_IMAGE = 'http://schemas.openxmlformats.org/officeDocument/2006/relationships/image' -EXT_TO_MIME = { - 'png': 'image/png', 'jpg': 'image/jpeg', 'jpeg': 'image/jpeg', - 'gif': 'image/gif', 'bmp': 'image/bmp', 'tiff': 'image/tiff', - 'webp': 'image/webp', -} - def unpack(docx_path, out_dir): """使用 zipfile 直接解包 .docx 到临时目录,替代外部 unpack.py 脚本。""" @@ -141,61 +124,146 @@ def get_images_info(docx_path): return build_image_index(tmpdir) -def replace_image(unpacked_dir, index, new_image_path): - """替换第 index 张图片(1-based)""" - imgs = build_image_index(unpacked_dir) - if index < 1 or index > len(imgs): - raise ValueError(f"图片序号 {index} 超出范围(共 {len(imgs)} 张)") +def _normalize_newlines(text): + if text is None: + return '' + return str(text).replace('\r\n', '\n').replace('\r', '\n') - info = imgs[index - 1] - old_abs = info['abs_path'] - old_ext = info['ext'] - new_ext = os.path.splitext(new_image_path)[1].lstrip('.').lower() - if new_ext == 'jpg': - new_ext = 'jpeg' - print(f" 图片#{index} {os.path.basename(info['media_file'])}({old_ext.upper()})" - f" ← {os.path.basename(new_image_path)}({new_ext.upper()})") +def _is_text_node(el): + return el.tag == f'{{{W}}}t' - if old_ext == new_ext: - # ── 同格式:直接覆盖 ────────────────────────────── - import shutil - shutil.copy2(new_image_path, old_abs) +def _is_break_node(el): + return el.tag in (f'{{{W}}}br', f'{{{W}}}cr') + + +def _is_tab_node(el): + return el.tag == f'{{{W}}}tab' + + +def _iter_run_text_parts(run_el): + for child in run_el: + if _is_text_node(child): + yield child, _normalize_newlines(child.text or '') + elif _is_break_node(child): + yield child, '\n' + elif _is_tab_node(child): + yield child, '\t' + + +def _run_text(run_el): + return ''.join(part for _, part in _iter_run_text_parts(run_el)) + + +def _paragraph_text(para_el): + return ''.join(_run_text(run) for run in para_el.iter(f'{{{W}}}r')) + + +def _clear_run_text_like_children(run_el): + for child in list(run_el): + if _is_text_node(child) or _is_break_node(child) or _is_tab_node(child): + run_el.remove(child) + + +def _append_text_to_run(run_el, text): + text = _normalize_newlines(text) + parts = text.split('\n') + + if len(parts) == 1: + t_el = etree.SubElement(run_el, f'{{{W}}}t') + t_el.text = parts[0] + if parts[0] and (parts[0][0] == ' ' or parts[0][-1] == ' '): + t_el.set('{http://www.w3.org/XML/1998/namespace}space', 'preserve') + return + + for idx, part in enumerate(parts): + if part: + t_el = etree.SubElement(run_el, f'{{{W}}}t') + t_el.text = part + if part[0] == ' ' or part[-1] == ' ': + t_el.set('{http://www.w3.org/XML/1998/namespace}space', 'preserve') + if idx < len(parts) - 1: + etree.SubElement(run_el, f'{{{W}}}br') + + +def _ensure_paragraph_run(para_el): + runs = list(para_el.findall(f'.//{{{W}}}r')) + if runs: + return runs[0] + + ppr = para_el.find(f'{{{W}}}pPr') + new_r = etree.Element(f'{{{W}}}r') + if ppr is None: + para_el.insert(0, new_r) else: - # ── 不同格式:Pillow 转换 + 更新 rels + ContentTypes - new_abs = os.path.splitext(old_abs)[0] + '.' + new_ext - img = Image.open(new_image_path) - fmt = {'jpeg': 'JPEG', 'png': 'PNG', 'gif': 'GIF', - 'bmp': 'BMP', 'tiff': 'TIFF', 'webp': 'WEBP'}.get(new_ext, new_ext.upper()) - if fmt == 'JPEG' and img.mode in ('RGBA', 'P'): - img = img.convert('RGB') - img.save(new_abs, format=fmt) - if os.path.abspath(new_abs) != os.path.abspath(old_abs): - os.remove(old_abs) + para_el.insert(para_el.index(ppr) + 1, new_r) + return new_r - # 更新 rels - old_media = info['media_file'] - new_media = os.path.splitext(old_media)[0] + '.' + new_ext - word_dir = os.path.join(unpacked_dir, 'word') - rels_path = os.path.join(word_dir, '_rels', 'document.xml.rels') - rels_tree = etree.parse(rels_path) - for rel in rels_tree.getroot(): - if rel.get('Id') == info['rid']: - rel.set('Target', new_media) - break - rels_tree.write(rels_path, xml_declaration=True, encoding='UTF-8', standalone=True) - # 更新 ContentTypes - ct_path = os.path.join(unpacked_dir, '[Content_Types].xml') - ct_tree = etree.parse(ct_path) - ct_root = ct_tree.getroot() - existing = {el.get('Extension', '') for el in ct_root} - if new_ext not in existing: - etree.SubElement(ct_root, 'Default', Extension=new_ext, - ContentType=EXT_TO_MIME.get(new_ext, f'image/{new_ext}')) - ct_tree.write(ct_path, xml_declaration=True, encoding='UTF-8', standalone=True) - print(f" 格式转换 {old_ext}→{new_ext},rels 和 ContentTypes 已更新") +def _set_paragraph_text(para_el, text): + runs = list(para_el.findall(f'.//{{{W}}}r')) + text_runs = [run for run in runs if any(True for _ in _iter_run_text_parts(run))] + + if text_runs: + first_run = text_runs[0] + for run in text_runs: + _clear_run_text_like_children(run) + else: + first_run = _ensure_paragraph_run(para_el) + _clear_run_text_like_children(first_run) + + _append_text_to_run(first_run, text) + + +def _paragraph_list(doc_el): + return list(doc_el.iter(f'{{{W}}}p')) + + +def _replace_paragraph_block(doc_el, old_text, new_text): + old_segments = _normalize_newlines(old_text).split('\n\n') + new_segments = _normalize_newlines(new_text).split('\n\n') + if len(old_segments) <= 1: + return False + + paras = _paragraph_list(doc_el) + para_texts = [_paragraph_text(p) for p in paras] + + match_start = None + for i in range(0, len(para_texts) - len(old_segments) + 1): + if para_texts[i:i + len(old_segments)] == old_segments: + match_start = i + break + + if match_start is None: + return False + + matched_paras = paras[match_start:match_start + len(old_segments)] + parent = matched_paras[0].getparent() + if parent is None: + return False + + anchor_index = parent.index(matched_paras[-1]) + + shared_count = min(len(matched_paras), len(new_segments)) + for idx in range(shared_count): + _set_paragraph_text(matched_paras[idx], new_segments[idx]) + + if len(new_segments) > len(matched_paras): + template_para = matched_paras[-1] + insert_at = anchor_index + 1 + for seg in new_segments[len(matched_paras):]: + new_para = copy.deepcopy(template_para) + _set_paragraph_text(new_para, seg) + parent.insert(insert_at, new_para) + insert_at += 1 + elif len(new_segments) < len(matched_paras): + for para in matched_paras[len(new_segments):]: + para_parent = para.getparent() + if para_parent is not None: + para_parent.remove(para) + + return True def paragraph_replace(para_el, replacements): @@ -213,20 +281,27 @@ def paragraph_replace(para_el, replacements): return # 收集所有文本元素及其位置信息 - t_elements = [] + text_runs = [] for run in runs: - for t_el in run.findall(f'{{{W}}}t'): - t_elements.append((run, t_el)) + if any(True for _ in _iter_run_text_parts(run)): + text_runs.append(run) - if not t_elements: + if not text_runs: return # 拼接完整文本 - full_text = ''.join(t_el.text or '' for _, t_el in t_elements) + full_text = _paragraph_text(para_el) original_text = full_text - # 执行所有替换 + normalized_replacements = [] for old, new in replacements: + normalized_replacements.append(( + _normalize_newlines(old), + _normalize_newlines(new), + )) + + # 执行所有替换 + for old, new in normalized_replacements: if old in full_text: full_text = full_text.replace(old, new) @@ -236,16 +311,11 @@ def paragraph_replace(para_el, replacements): print(f"段落替换: {len(original_text)} -> {len(full_text)} 字符") - # 将新文本重新分配到原有的 元素中 - # 策略:将所有文本放入第一个元素,清空其他元素,避免不当切分导致换行 - _, first_t_el = t_elements[0] - first_t_el.text = full_text - if full_text and (full_text[0] == ' ' or full_text[-1] == ' '): - first_t_el.set('{http://www.w3.org/XML/1998/namespace}space', 'preserve') - - # 清空其他 元素 - for i in range(1, len(t_elements)): - t_elements[i][1].text = '' + # 将规范化文本重新写回第一个文本 run,\n 会回写成 Word 的换行节点。 + first_run = text_runs[0] + for run in text_runs: + _clear_run_text_like_children(run) + _append_text_to_run(first_run, full_text) def ensure_rpr(run_el): @@ -271,13 +341,15 @@ def apply_color_to_keyword(doc_el, keyword, hex_color, context_text=None): 当 context_text 不为空时,只在“整段文本包含该 context_text 的段落”中进行上色, 避免同一个关键字在其他段落里被误伤(例如单独的数字 0)。 """ + keyword = _normalize_newlines(keyword) + context_text = _normalize_newlines(context_text) if context_text is not None else None + # 如果提供了上下文,只在包含该上下文的段落内着色 allowed_paras = None if context_text: allowed_paras = set() for p in doc_el.iter(f'{{{W}}}p'): - t_nodes = list(p.iter(f'{{{W}}}t')) - full = ''.join(t.text or '' for t in t_nodes) + full = _paragraph_text(p) if context_text in full: allowed_paras.add(p) @@ -294,10 +366,9 @@ def apply_color_to_keyword(doc_el, keyword, hex_color, context_text=None): para = _find_ancestor_para(run) if para not in allowed_paras: continue - t_nodes = list(run.findall(f'{{{W}}}t')) - if not t_nodes: + full_text = _run_text(run) + if not full_text: continue - full_text = ''.join(t.text or '' for t in t_nodes) if keyword not in full_text: continue @@ -317,10 +388,7 @@ def apply_color_to_keyword(doc_el, keyword, hex_color, context_text=None): new_r = etree.Element(f'{{{W}}}r') if rpr_bytes is not None: new_r.append(etree.fromstring(rpr_bytes)) - t_el = etree.SubElement(new_r, f'{{{W}}}t') - t_el.text = text - if text and (text[0] == ' ' or text[-1] == ' '): - t_el.set('{http://www.w3.org/XML/1998/namespace}space', 'preserve') + _append_text_to_run(new_r, text) if colored: set_color_on_rpr(ensure_rpr(new_r), hex_color) return new_r @@ -405,19 +473,13 @@ def remove_rule_blocks(doc_el): if parent is not None: parent.remove(p) -def process(input_docx, output_docx, replacements, image_replacements, - color_keywords): +def process(input_docx, output_docx, replacements, color_keywords): with tempfile.TemporaryDirectory() as tmpdir: print(f"📂 解包 {input_docx} ...") unpack(input_docx, tmpdir) doc_xml_path = os.path.join(tmpdir, 'word', 'document.xml') - if image_replacements: - print(f"🖼️ 替换 {len(image_replacements)} 张图片...") - for idx, new_img in image_replacements: - replace_image(tmpdir, idx, new_img) - tree = etree.parse(doc_xml_path) root = tree.getroot() @@ -426,8 +488,17 @@ def process(input_docx, output_docx, replacements, image_replacements, if replacements: print(f"✏️ 替换 {len(replacements)} 条文本...") - for para in root.iter(f'{{{W}}}p'): - paragraph_replace(para, replacements) + remaining_replacements = [] + for old, new in replacements: + if '\n\n' in _normalize_newlines(old): + replaced = _replace_paragraph_block(root, old, new) + if replaced: + print("🧩 跨段替换命中") + continue + remaining_replacements.append((old, new)) + if remaining_replacements: + for para in root.iter(f'{{{W}}}p'): + paragraph_replace(para, remaining_replacements) # 根据 span 解析出的关键字上色 for item in color_keywords: @@ -457,6 +528,8 @@ def _parse_span_replacement(new_text): """ import re + new_text = _normalize_newlines(new_text) + # 简单的命名颜色到 16 进制的映射,可按需扩展 named_colors = { 'red': 'FF0000', @@ -505,33 +578,33 @@ def _parse_span_replacement(new_text): re.IGNORECASE | re.DOTALL, ) - # 先得到去掉 span 标签后的纯文本(也是最终会写入 DOCX 的内容) + # 先按段落边界拆分,这样 span 上色时可以使用所在段落作为上下文。 def _strip_repl(m): return m.group(2) - plain_text = span_pattern.sub(_strip_repl, new_text) - - # 再次遍历 span,收集颜色关键字,并把“整句纯文本”作为上下文挂在每个关键字上 + plain_segments = [] color_keywords = [] - for m in span_pattern.finditer(new_text): - raw_color = m.group(1) - hex_color = _normalize_color(raw_color) - keyword = m.group(2) - # 三元组: (关键字, 颜色, 该 NEW 对应的整句纯文本上下文) - color_keywords.append((keyword, hex_color, plain_text)) + for segment in new_text.split('\n\n'): + plain_segment = span_pattern.sub(_strip_repl, segment) + plain_segments.append(plain_segment) + for m in span_pattern.finditer(segment): + raw_color = m.group(1) + hex_color = _normalize_color(raw_color) + keyword = m.group(2) + # 三元组: (关键字, 颜色, 所在段落的纯文本上下文) + color_keywords.append((keyword, hex_color, plain_segment)) + plain_text = '\n\n'.join(plain_segments) return plain_text, color_keywords def main(): - parser = argparse.ArgumentParser(description='DOCX 格式保留:替换文本/图片/颜色') + parser = argparse.ArgumentParser(description='DOCX 格式保留:替换文本/颜色') parser.add_argument('input', help='输入 .docx') parser.add_argument('output', nargs='?', help='输出 .docx') parser.add_argument('--list-images', action='store_true', help='列出所有图片') parser.add_argument('--replace', nargs=2, metavar=('OLD', 'NEW'), action='append', default=[]) - parser.add_argument('--image', nargs=2, metavar=('INDEX', 'FILE'), - action='append', default=[], help='图片替换') args = parser.parse_args() if args.list_images: @@ -549,12 +622,11 @@ def main(): color_keywords.extend(spans) process( - input_docx = args.input, - output_docx = args.output, - replacements = replacements, - image_replacements= [(int(i), f) for i, f in args.image], - color_keywords = color_keywords, + input_docx=args.input, + output_docx=args.output, + replacements=replacements, + color_keywords=color_keywords, ) if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/mcp_docx_server.py b/mcp_docx_server.py index 364c4f4..b3986c6 100644 --- a/mcp_docx_server.py +++ b/mcp_docx_server.py @@ -2,9 +2,9 @@ """ 基于 mcp_docx.py 封装的 MCP 服务器。 -暴露两个主要 MCP 工具: +暴露主要 MCP 工具: - list_docx_images:列出 DOCX 中的图片信息 -- edit_docx: 进行文本替换 / 关键字上色 / 图片替换 +- edit_docx: 进行文本替换 / 关键字上色(与 HTTP POST /edit_docx 能力一致) 额外提供 HTTP 文件接口(仅在 http 模式下可用): - POST /upload: 上传文件到服务器 @@ -29,12 +29,17 @@ """ import argparse +import hashlib +import json import os +import shutil import tempfile +import time import urllib.parse +import zipfile +from contextlib import contextmanager from datetime import datetime, date, timedelta from typing import Any, Dict, List, Optional -import uuid import requests from lxml import etree @@ -43,6 +48,7 @@ from mcp.server.transport_security import TransportSecuritySettings from mcp_docx import ( W, + _normalize_newlines, get_images_info, process, _parse_span_replacement, @@ -50,8 +56,18 @@ from mcp_docx import ( unpack, pack, ) +# HTTP 远程模式:添加文件上传下载路由 +from starlette.responses import FileResponse, JSONResponse +from starlette.background import BackgroundTask +from starlette.requests import Request +from starlette.responses import FileResponse, JSONResponse -_disable_dns_rebinding = os.getenv("MCP_DISABLE_HOST_CHECK") == "1" +if os.name == "nt": + import msvcrt +else: + import fcntl + +_disable_dns_rebinding = True if _disable_dns_rebinding: # 参考 python-sdk 官方文档:关闭 DNS rebinding 防护(适合本地或已由外层网关做安全控制的环境) @@ -64,8 +80,8 @@ else: # 如需通过网关 / 域名访问,可在这里追加 allowed_hosts / allowed_origins transport_security = TransportSecuritySettings( enable_dns_rebinding_protection=True, - allowed_hosts=["localhost:*", "127.0.0.1:*", "192.168.10.101:*"], - allowed_origins=["http://localhost:*", "http://127.0.0.1:*","http://192.168.10.101:*"], + allowed_hosts=["localhost:*", "127.0.0.1:*", "192.168.1.13:*","10.150.172.13:*"], + allowed_origins=["http://localhost:*", "http://127.0.0.1:*","http://192.168.1.13:*","http://10.150.172.13:*"], ) @@ -234,6 +250,53 @@ def _download_to_temp(url: str, suffix: str = ".tmp") -> str: return tmp_path +def _safe_filename(filename: Optional[str], default: str = "uploaded.docx") -> str: + """提取安全文件名,避免路径穿越。""" + if not filename: + return default + decoded = urllib.parse.unquote(str(filename)) + safe_name = os.path.basename(decoded).strip() + return safe_name or default + + +def _filename_from_url(url: str, default: str = "uploaded.docx") -> str: + """从 URL 中推断文件名,优先读取 query 参数中的 filename。""" + parsed = urllib.parse.urlparse(url) + query = urllib.parse.parse_qs(parsed.query) + + for key in ("filename", "fileName", "name"): + values = query.get(key) + if values: + return _safe_filename(values[0], default=default) + + return _safe_filename(os.path.basename(parsed.path), default=default) + + +def _download_to_path(url: str, local_path: str) -> None: + """将远程 URL 下载到指定路径,完成后原子覆盖目标文件。""" + resp = requests.get(url, stream=True, timeout=30) + resp.raise_for_status() + + parent_dir = os.path.dirname(local_path) or "." + os.makedirs(parent_dir, exist_ok=True) + fd, tmp_path = tempfile.mkstemp( + suffix=os.path.splitext(local_path)[1] or ".tmp", + dir=parent_dir, + ) + try: + with os.fdopen(fd, "wb") as f: + for chunk in resp.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + os.replace(tmp_path, local_path) + except Exception: + try: + os.remove(tmp_path) + except OSError: + pass + raise + + def _build_output_url(abs_output_path: str) -> Optional[str]: """ 构造输出文件的下载 URL。 @@ -245,18 +308,37 @@ def _build_output_url(abs_output_path: str) -> Optional[str]: - 否则在 http 模式下: http://host:port/download/{filename} - stdio 模式下: 返回 None """ + filename = os.path.basename(abs_output_path) + encoded_filename = urllib.parse.quote(filename) + + def _append_filename(base_url: str) -> str: + parsed = urllib.parse.urlparse(base_url) + query = urllib.parse.parse_qsl(parsed.query, keep_blank_values=True) + + for index, (key, _) in enumerate(query): + if key in ("filename", "fileName", "name"): + query[index] = (key, filename) + return urllib.parse.urlunparse( + parsed._replace(query=urllib.parse.urlencode(query)) + ) + + if parsed.path.rstrip("/").endswith("/download"): + query.append(("filename", filename)) + return urllib.parse.urlunparse( + parsed._replace(query=urllib.parse.urlencode(query)) + ) + + return base_url.rstrip("/") + "/" + encoded_filename + # 优先使用环境变量 base = os.getenv("MCP_OUTPUT_BASE_URL") if base: - filename = os.path.basename(abs_output_path) - return base.rstrip("/") + "/" + filename + return _append_filename(base) # 如果是 http 模式,自动构建下载 URL if _server_config["transport"] == "http": host = _server_config["host"] port = _server_config["port"] - filename = os.path.basename(abs_output_path) - # 如果 host 是 0.0.0.0,尝试使用更具体的地址 if host == "0.0.0.0": # 优先使用环境变量指定的公网地址 @@ -267,7 +349,7 @@ def _build_output_url(abs_output_path: str) -> Optional[str]: # 默认使用 localhost host = "localhost" - return f"http://{host}:{port}/download/{filename}" + return _append_filename(f"http://{host}:{port}/download") return None @@ -294,6 +376,136 @@ def _get_tmp_upload_dir() -> str: return os.path.abspath(tmp_dir) +def _get_lock_dir() -> str: + """获取文件锁目录。""" + lock_dir = os.path.join(_get_upload_dir(), ".locks") + os.makedirs(lock_dir, exist_ok=True) + return lock_dir + + +def _get_lock_path(target_path: str) -> str: + """根据目标文件路径生成稳定的锁文件路径。""" + abs_target = os.path.abspath(target_path) + base_name = _safe_filename(os.path.basename(abs_target), default="file") + digest = hashlib.sha256(abs_target.encode("utf-8")).hexdigest() + return os.path.join(_get_lock_dir(), f"{base_name}.{digest}.lock") + + +def _acquire_lock(handle) -> None: + """跨进程独占锁。""" + handle.seek(0, os.SEEK_END) + if handle.tell() == 0: + handle.write(b"0") + handle.flush() + handle.seek(0) + + if os.name == "nt": + while True: + try: + msvcrt.locking(handle.fileno(), msvcrt.LK_LOCK, 1) + break + except OSError: + time.sleep(0.05) + else: + fcntl.flock(handle.fileno(), fcntl.LOCK_EX) + + +def _release_lock(handle) -> None: + """释放跨进程独占锁。""" + handle.seek(0) + if os.name == "nt": + msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1) + else: + fcntl.flock(handle.fileno(), fcntl.LOCK_UN) + + +@contextmanager +def _file_lock(target_path: str): + """针对目标文件路径获取独占文件锁。""" + lock_path = _get_lock_path(target_path) + with open(lock_path, "a+b") as handle: + _acquire_lock(handle) + try: + yield + finally: + _release_lock(handle) + + +def _write_bytes_atomic(file_path: str, content: bytes) -> None: + """原子写入文件内容。""" + parent_dir = os.path.dirname(file_path) or "." + os.makedirs(parent_dir, exist_ok=True) + + fd, tmp_path = tempfile.mkstemp( + suffix=os.path.splitext(file_path)[1] or ".tmp", + dir=parent_dir, + ) + try: + with os.fdopen(fd, "wb") as f: + f.write(content) + os.replace(tmp_path, file_path) + except Exception: + try: + os.remove(tmp_path) + except OSError: + pass + raise + + +def _snapshot_file(file_path: str) -> str: + """复制文件快照,供下载接口在释放锁后返回。""" + suffix = os.path.splitext(file_path)[1] or ".tmp" + fd, snapshot_path = tempfile.mkstemp(suffix=suffix) + try: + with os.fdopen(fd, "wb") as dst, open(file_path, "rb") as src: + shutil.copyfileobj(src, dst) + except Exception: + try: + os.remove(snapshot_path) + except OSError: + pass + raise + return snapshot_path + + +def _cleanup_temp_file(file_path: str) -> None: + try: + os.remove(file_path) + except OSError: + pass + + +def _validate_docx_file(file_path: str) -> None: + if not os.path.exists(file_path): + raise FileNotFoundError(f"输入 DOCX 文件不存在: {file_path}") + if not os.path.isfile(file_path): + raise FileNotFoundError(f"输入路径不是文件: {file_path}") + + size = os.path.getsize(file_path) + if size <= 0: + raise ValueError(f"输入 DOCX 文件为空: {os.path.basename(file_path)}") + if not zipfile.is_zipfile(file_path): + raise ValueError(f"输入文件不是合法的 DOCX/ZIP: {os.path.basename(file_path)}") + + +def _resolve_edit_target_path(input_docx_path: str, upload_dir: str) -> str: + """ + 将编辑输入统一解析为本地路径。 + + - URL 输入会先下载到 uploads 目录 + - 相对路径按 uploads 下的文件名处理 + - 绝对路径直接使用 + """ + if _is_url(input_docx_path): + filename = _filename_from_url(input_docx_path) + return os.path.join(upload_dir, filename) + + if os.path.isabs(input_docx_path): + return input_docx_path + + return os.path.join(upload_dir, _safe_filename(input_docx_path)) + + @mcp.tool() async def list_docx_images(docx_url: str) -> List[Dict[str, Any]]: """ @@ -313,72 +525,34 @@ async def list_docx_images(docx_url: str) -> List[Dict[str, Any]]: imgs = get_images_info(_download_to_temp(docx_url, suffix=".docx")) return imgs -@mcp.custom_route("/edit_docx", methods=["POST"]) -async def edit_docx_handler(request: Request): - data = await request.json() - input_docx_path = data.get("input_docx_path") - replacements = data.get("replacements") - image_replacements = data.get("image_replacements") - report_type = data.get("report_type") - report_title_time = data.get("report_title_time") + +def _edit_docx_core( + input_docx_path: str, + replacements: Optional[List[Dict[str, Any]]], + report_type: Optional[str], + report_title_time: Optional[str], +) -> Dict[str, Any]: """ - 对 DOCX 文件进行编辑。 - - 支持: - - 纯文本替换 - - 通过 关键字 语法设置关键字颜色 - - 替换指定序号的图片 - - 报告日期与期数自动替换(仅在“目录”之前生效) - - 参数: - - input_docx_path: 输入 DOCX 文件名称 - - replacements: 文本替换规则列表,例如: - [ - {"old": "计划作业总数共有10项。", "new": "计划作业总数共有XX项。"}, - {"old": "文档原文本,必须是完整的一句话或者段落", "new": "要替换的文本"} - ] - - image_replacements: 图片替换规则 - - report_type: 报告类型,可选值:日报 / 周报 / 月报(或对应的英文 daily / weekly / monthly) - - report_title_time: 报告标题中要显示的时间字符串,用来替换“YYYY年M月”这一段(仅在第一次匹配时生效) + 对 DOCX 文件进行编辑(与 HTTP /edit_docx 共用逻辑)。 返回: - - { - "output_path": 生成的 DOCX 绝对路径, - "output_url": 如果配置了 MCP_OUTPUT_BASE_URL,则为可访问该文件的 URL,否则为 null - } + - {"output_path": 绝对路径, "output_url": URL 或 None} """ - tmp_input: Optional[str] = None - tmp_images: List[str] = [] - print(f"edit_docx: input_docx_path: {input_docx_path}, replacements: {replacements}, image_replacements: {image_replacements}") + print(f"edit_docx: input_docx_path: {input_docx_path}, replacements: {replacements}") + upload_dir = _get_upload_dir() + local_input = _resolve_edit_target_path(input_docx_path, upload_dir) + lock_cm = _file_lock(local_input) + lock_cm.__enter__() try: - upload_dir = _get_upload_dir() # 输出目录:/uploads - tmp_upload_dir = _get_tmp_upload_dir() # 上传临时目录:/tmp - # 解析输入路径:支持 URL、绝对路径、仅文件名三种形式 - local_input = input_docx_path if _is_url(input_docx_path): - parsed = urllib.parse.urlparse(input_docx_path) - ext = os.path.splitext(parsed.path)[1] or ".docx" - tmp_input = _download_to_temp(input_docx_path, suffix=ext) - local_input = tmp_input - elif not os.path.isabs(local_input): - # 相对路径:优先在 tmp,其次在 uploads 中查找 - cand_tmp = os.path.join(tmp_upload_dir, input_docx_path) - cand_upload = os.path.join(upload_dir, input_docx_path) - if os.path.exists(cand_tmp): - local_input = cand_tmp - else: - local_input = cand_upload + _download_to_path(input_docx_path, local_input) - if not os.path.exists(local_input): - raise FileNotFoundError(f"输入 DOCX 文件不存在: {input_docx_path}") + _validate_docx_file(local_input) if replacements is None: replacements = [] - if image_replacements is None: - image_replacements = [] - # 解析文本替换与颜色关键字(复用 CLI 逻辑) rep_pairs = [] color_keywords = [] for item in replacements: @@ -386,55 +560,25 @@ async def edit_docx_handler(request: Request): new_raw = item.get("new") if not old: continue + old = _normalize_newlines(old) if new_raw is None: new_raw = "" + else: + new_raw = _normalize_newlines(new_raw) new_plain, spans = _parse_span_replacement(new_raw) rep_pairs.append((old, new_plain)) color_keywords.extend(spans) - # 处理图片替换参数(支持本地路径或 URL) - img_pairs = [] - for item in image_replacements: - try: - idx = int(item.get("index")) - except (TypeError, ValueError): - continue - - path = item.get("file") - if not path: - continue - - local_img = path - if _is_url(path): - parsed = urllib.parse.urlparse(path) - ext = os.path.splitext(parsed.path)[1] or "" - suffix = ext if ext else ".img" - tmp_img = _download_to_temp(path, suffix=suffix) - tmp_images.append(tmp_img) - local_img = tmp_img - - if not os.path.exists(local_img): - raise FileNotFoundError(f"图片文件不存在: {path}") - - img_pairs.append((idx, local_img)) - - # 复用原始处理函数: - # 输出文件统一写入 /uploads 目录,文件名带时间戳和随机后缀避免并发冲突 - base_name = os.path.basename(local_input) - name_root, _ = os.path.splitext(base_name) - ts = datetime.now().strftime('%Y%m%d%H%M%S') - rand = uuid.uuid4().hex[:6] - output_filename = f"{name_root}_output_{ts}_{rand}.docx" - output_docx = os.path.join(upload_dir, output_filename) + parent_dir = os.path.dirname(local_input) or "." + fd, output_docx = tempfile.mkstemp(suffix=".docx", dir=parent_dir) + os.close(fd) process( input_docx=local_input, output_docx=output_docx, replacements=rep_pairs, - image_replacements=img_pairs, color_keywords=color_keywords, ) - # 追加:根据报告类型与标题时间,在“目录”之前自动处理日期和期数 if report_type or report_title_time: try: _apply_report_date_logic_to_docx( @@ -443,41 +587,95 @@ async def edit_docx_handler(request: Request): report_title_time=report_title_time, ) except Exception as e: - # 避免因为日期处理失败而导致整个接口报错,把错误写到日志即可 print(f"apply report date logic failed: {e}") - abs_out = os.path.abspath(output_docx) + os.replace(output_docx, local_input) + abs_out = os.path.abspath(local_input) - # 删除上传的临时文件:只删除位于 tmp 目录中的输入文件 - try: - tmp_root = _get_tmp_upload_dir() - if os.path.exists(local_input): - abs_input = os.path.abspath(local_input) - if os.path.commonpath([abs_input, tmp_root]) == tmp_root: - os.remove(local_input) - except Exception: - # 不因清理失败影响主流程 - pass return { - "output_path": output_docx, - "output_url": _build_output_url(output_docx), + "output_path": abs_out, + "output_url": _build_output_url(abs_out), } - finally: - if tmp_input and os.path.exists(tmp_input): + except Exception: + if 'output_docx' in locals() and os.path.exists(output_docx): try: - os.remove(tmp_input) + os.remove(output_docx) except OSError: pass + raise + finally: + lock_cm.__exit__(None, None, None) + + +@mcp.tool() +def edit_docx( + input_docx_path: str, + replacements: List[Dict[str, Any]], + report_type: str, + report_title_time: str, +) -> Dict[str, Any]: + """ + 对 DOCX 进行文本替换 / 关键字上色。 + + 支持: + - 纯文本替换 + - 通过 关键字 语法设置关键字颜色 + - 报告日期与期数自动替换(仅在“目录”之前生效) + + 参数: + - input_docx_path: 输入文件名 + - replacements: 替换规则列表,例如 + [{"old": "原文", "new": "新文"}] + - report_type: 日报/周报/月报 或 daily/weekly/monthly + - report_title_time: 替换标题中「YYYY年M月」为指定字符串(首次匹配) + + 返回: + - 成功: {"success": true, "output_path": ..., "output_url": ...} + - 失败: {"success": false, "message": "..."} + """ + try: + out = _edit_docx_core( + input_docx_path, + replacements, + report_type, + report_title_time, + ) + return {"success": True, **out} + except Exception as e: + return {"success": False, "message": str(e)} + + +@mcp.custom_route("/edit_docx", methods=["POST"]) +async def edit_docx_handler(request: Request): + try: + data = await request.json() + except json.JSONDecodeError: + return JSONResponse( + { + "success": False, + "message": ( + "请求体必须是合法的 JSON 对象。" + "请使用 Content-Type: application/json,并发送非空的 JSON body" + "(空 body、form-data 或 urlencoded 会导致此错误)。" + ), + }, + status_code=400, + ) + input_docx_path = data.get("input_docx_path") + replacements = data.get("replacements") + report_type = data.get("report_type") + report_title_time = data.get("report_title_time") + try: + result = _edit_docx_core( + input_docx_path, + replacements, + report_type, + report_title_time, + ) + return JSONResponse(result) + except Exception as e: + return JSONResponse({"success": False, "message": str(e)}, status_code=500) - for p in tmp_images: - if os.path.exists(p): - try: - os.remove(p) - except OSError: - pass -# HTTP 远程模式:添加文件上传下载路由 -from starlette.responses import FileResponse, JSONResponse -from starlette.requests import Request def _get_log_path() -> str: """ 获取日志文件路径。 @@ -541,28 +739,33 @@ async def upload_handler(request: Request): file = form.get("file") if not file: - return JSONResponse({ + return JSONResponse({ "success": False, "message": "未提供文件" }, status_code=400) - tmp_dir = _get_tmp_upload_dir() + upload_dir = _get_upload_dir() orig_filename = file.filename or "uploaded.docx" # 安全检查:防止路径遍历攻击,保留原始文件名 - filename = os.path.basename(orig_filename) - file_path = os.path.join(tmp_dir, filename) + filename = _safe_filename(orig_filename) + file_path = os.path.join(upload_dir, filename) - # 保存文件到临时目录(如已存在则覆盖) + # 保存文件到 uploads 目录(如已存在则覆盖) content = await file.read() - with open(file_path, "wb") as f: - f.write(content) + if not content: + return JSONResponse({ + "success": False, + "message": f"上传文件为空: {filename}" + }, status_code=400) + with _file_lock(file_path): + _write_bytes_atomic(file_path, content) return JSONResponse({ "success": True, "filename": filename, # 保留原始文件名,供 edit_docx 使用 "file_path": file_path, # 绝对路径(可选) - "file_url": None, # 临时文件不提供下载 URL + "file_url": _build_output_url(file_path), "size": len(content), "message": f"文件上传成功: {filename}" }) @@ -571,33 +774,57 @@ async def upload_handler(request: Request): "success": False, "message": f"文件上传失败: {str(e)}" }, status_code=500) +@mcp.custom_route("/download", methods=["GET"]) @mcp.custom_route("/download/{filename}", methods=["GET"]) async def download_handler(request: Request): """处理文件下载""" try: - filename = request.path_params.get("filename") + filename = ( + request.path_params.get("filename") + or request.query_params.get("filename") + or request.query_params.get("fileName") + or request.query_params.get("name") + ) upload_dir = _get_upload_dir() + download_filename = ( + request.query_params.get("download_filename") + or request.query_params.get("new_filename") + or request.query_params.get("rename_filename") + ) + + if not filename: + return JSONResponse({ + "success": False, + "message": "缺少 filename 参数" + }, status_code=400) # 安全检查:防止路径遍历攻击 - filename = os.path.basename(filename) + filename = _safe_filename(filename) file_path = os.path.join(upload_dir, filename) - - if not os.path.exists(file_path): - return JSONResponse({ + lock_cm = _file_lock(file_path) + lock_cm.__enter__() + try: + if not os.path.exists(file_path): + return JSONResponse({ "success": False, "message": f"文件不存在: {filename}" }, status_code=404) - if not os.path.isfile(file_path): - return JSONResponse({ + if not os.path.isfile(file_path): + return JSONResponse({ "success": False, "message": f"不是文件: {filename}" }, status_code=400) + snapshot_path = _snapshot_file(file_path) + finally: + lock_cm.__exit__(None, None, None) + return FileResponse( - file_path, - filename=filename, - media_type="application/octet-stream" + snapshot_path, + filename=_safe_filename(download_filename, default=filename), + media_type="application/octet-stream", + background=BackgroundTask(_cleanup_temp_file, snapshot_path), ) except Exception as e: return JSONResponse({