import os
import time
import imageio
import whisper
import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity as ssim
import tempfile

# 添加FFmpeg路径(根据你的实际安装路径修改)
os.environ["PATH"] += os.pathsep + r"D:\ffmpeg\bin"  # 例如:D:\ffmpeg\bin
# ============================== 配置参数 ==============================
# 示例:将视频复制到 D:\test\input.mp4
VIDEO_PATH = "D:/python项目文件/1/input2.mp4"  # 输入视频路径
MODEL_DIR = "D:/whisper_models"  # 手动下载的模型存放目录
SSIM_THRESHOLD = 0.85  # 关键帧去重阈值
FRAME_INTERVAL = 2  # 抽帧间隔(秒)
OUTPUT_DIR = "output2"  # 输出目录


# =====================================================================

def extract_keyframes_with_time(video_path: str) -> tuple:
    """改进版关键帧提取(返回关键帧图像列表和时间戳列表)"""
    try:
        # 初始化视频读取器
        reader = imageio.get_reader(video_path, 'ffmpeg')
        fps = reader.get_meta_data().get('fps', 30)
        print(f"视频帧率: {fps}fps, 总时长: {reader.get_meta_data()['duration']:.1f}秒")

        keyframes = []
        keyframe_times = []
        prev_frame = None
        frame_counter = 0

        for i, frame in enumerate(reader):
            # 按间隔抽帧(默认每秒抽帧改为每FRAME_INTERVAL秒抽帧)
            if i % int(fps * FRAME_INTERVAL) != 0:
                continue

            current_time = i / fps
            # 降采样至320x240加速处理
            curr_frame = Image.fromarray(frame).resize((320, 240))

            if prev_frame is None:
                # 首帧强制保留
                keyframes.append(curr_frame)
                keyframe_times.append(current_time)
                prev_frame = np.array(curr_frame.convert('L'))
            else:
                # 计算灰度图SSIM
                curr_gray = np.array(curr_frame.convert('L'))
                score = ssim(prev_frame, curr_gray, data_range=255)

                if score < SSIM_THRESHOLD:
                    keyframes.append(curr_frame)
                    keyframe_times.append(current_time)
                    prev_frame = curr_gray

            frame_counter += 1
            if frame_counter % 10 == 0:
                print(f"已处理 {current_time:.1f}秒...")

        reader.close()
        print(f"关键帧提取完成,共{len(keyframes)}帧")
        return keyframes, keyframe_times
    except Exception as e:
        print(f"视频处理失败: {str(e)}")
        return [], []


def align_text_with_keyframes(video_path: str, keyframe_times: list) -> list:
    try:
        # 1. 动态添加 FFmpeg 路径
        ffmpeg_bin = r"D:\ffmpeg\bin"
        os.environ["PATH"] = ffmpeg_bin + os.pathsep + os.environ["PATH"]

        # 2. 加载模型
        model = whisper.load_model("tiny", device="cpu")

        # 3. 执行语音识别(不再传递 ffmpeg_path)
        result = model.transcribe(video_path, fp16=False)

        # 4. 对齐处理
        alignment = []
        kf_ptr = 0
        for seg in result["segments"]:
            seg_start = seg["start"]
            seg_end = seg["end"]
            matched_time = None
            while kf_ptr < len(keyframe_times):
                if keyframe_times[kf_ptr] <= seg_end:
                    matched_time = keyframe_times[kf_ptr]
                    kf_ptr += 1
                else:
                    break
            if matched_time is not None:
                alignment.append({
                    "text": seg["text"].strip(),
                    "start": seg_start,
                    "end": seg_end,
                    "keyframe_time": matched_time
                })
        return alignment
    except Exception as e:
        print(f"语音处理失败: {str(e)}")
        return []

def save_results(keyframes, alignment):
    """保存关键帧和文本对齐结果"""
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # 保存关键帧
    for i, img in enumerate(keyframes):
        img.save(os.path.join(OUTPUT_DIR, f"frame_{i:04d}.jpg"))

    # 保存对齐文本
    with open(os.path.join(OUTPUT_DIR, "alignment.txt"), "w", encoding="utf-8") as f:
        for item in alignment:
            f.write(
                f"[{item['keyframe_time']:.1f}s] "
                f"({item['start']:.1f}-{item['end']:.1f}s): "
                f"{item['text']}\n"
            )
    print(f"结果已保存至{OUTPUT_DIR}目录")



# 打印临时目录路径并检查可写权限
temp_dir = tempfile.gettempdir()
print(f"临时目录: {temp_dir}")
if not os.access(temp_dir, os.W_OK):
    print("错误:临时目录不可写!")
else:
    print("临时目录可写")

if __name__ == "__main__":
    # 步骤1: 提取关键帧
    keyframes, keyframe_times = extract_keyframes_with_time(VIDEO_PATH)
    if not keyframes:
        exit()

    # 步骤2: 语音对齐
    alignment = align_text_with_keyframes(VIDEO_PATH, keyframe_times)

    # 步骤3: 保存结果
    save_results(keyframes, alignment)