from fastapi import FastAPI, Response
from fastapi.responses import StreamingResponse
from PIL import Image
import io
import random

app = FastAPI(title="Inference API")

def generate_placeholder_image(width=512, height=512, color=None):
    """生成纯色占位图"""
    if color is None:
        # 随机生成颜色
        color = (
            random.randint(0, 255),
            random.randint(0, 255),
            random.randint(0, 255)
        )
    
    img = Image.new('RGB', (width, height), color)
    
    # 保存到字节流
    img_byte_arr = io.BytesIO()
    img.save(img_byte_arr, format='PNG')
    img_byte_arr.seek(0)
    
    return img_byte_arr

@app.get("/")
async def root():
    return {"message": "FastAPI Inference Service", "endpoints": ["/inference"]}

@app.get("/inference")
async def inference():
    """
    模拟推理接口，返回 3 张纯色占位图
    返回包含 3 张图片的响应（使用 multipart 或使用 JSON 返回 base64）
    这里使用多次返回的方式，实际中可以打包成 zip 或分别返回
    """
    images = []
    
    for i in range(3):
        # 生成不同的纯色图
        colors = [
            (255, 100, 100),  # 红色
            (100, 255, 100),  # 绿色
            (100, 100, 255),  # 蓝色
        ]
        img_bytes = generate_placeholder_image(
            width=512, 
            height=512, 
            color=colors[i]
        )
        images.append(img_bytes)
    
    # 返回第一张图片（演示用）
    # 实际场景中你可能想返回所有图片（zip 或 JSON with base64）
    return Response(
        content=images[0].getvalue(),
        media_type="image/png",
        headers={"X-Image-Count": "3", "X-Image-Index": "1"}
    )

@app.get("/inference/all")
async def inference_all():
    """返回所有 3 张图片的 base64 编码"""
    import base64
    
    results = []
    colors = [
        {"name": "red", "rgb": (255, 100, 100)},
        {"name": "green", "rgb": (100, 255, 100)},
        {"name": "blue", "rgb": (100, 100, 255)},
    ]
    
    for i, color_info in enumerate(colors):
        img_bytes = generate_placeholder_image(
            width=512,
            height=512,
            color=color_info["rgb"]
        )
        img_base64 = base64.b64encode(img_bytes.getvalue()).decode('utf-8')
        results.append({
            "index": i + 1,
            "color_name": color_info["name"],
            "color_rgb": color_info["rgb"],
            "image_base64": img_base64[:100] + "..."  # 截断，实际使用时返回完整 base64
        })
    
    return {"count": 3, "images": results}

@app.get("/inference/{index}")
async def inference_by_index(index: int):
    """根据索引返回特定图片 (1-3)"""
    if index < 1 or index > 3:
        return {"error": "Index must be between 1 and 3"}
    
    colors = [
        (255, 100, 100),  # 红色
        (100, 255, 100),  # 绿色
        (100, 100, 255),  # 蓝色
    ]
    
    img_bytes = generate_placeholder_image(
        width=512,
        height=512,
        color=colors[index - 1]
    )
    
    return Response(
        content=img_bytes.getvalue(),
        media_type="image/png",
        headers={"X-Image-Index": str(index)}
    )

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)
