Files

621 lines
21 KiB
Python
Raw Permalink Normal View History

2026-03-25 02:04:52 -05:00
"""
Gateway Checker API FastAPI server.
Run: uvicorn api:app --host 0.0.0.0 --port 8000
"""
import asyncio
import hashlib
import logging
import os
import re
import time
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import db
from comwave_auth import checker
from comwave_charge import check_card_3
load_dotenv()
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
log = logging.getLogger("api")
MASTER_KEY = os.getenv("MASTER_KEY", "")
if not MASTER_KEY:
log.warning("MASTER_KEY not set in .env — admin endpoints will reject all requests")
# ── Gateway Registry ──
GATEWAYS = {
"comwave": {
"fn": checker.check_card,
"type": "$0 auth",
"status": "online",
"max_concurrent": 1,
"timeout": 30,
"cooldown": 20,
"max_queue": 50,
"init": checker.ensure_login,
"shutdown": checker.shutdown,
},
"comwave3": {
"fn": check_card_3,
"type": "$3.33 charge",
"status": "online",
"max_concurrent": 5,
"timeout": 60,
"cooldown": 20,
"max_queue": 50,
"init": None,
"shutdown": None,
},
}
# ── In-Memory State ──
tasks: dict[str, dict] = {} # task_id → {api_key, status, result, gateway, created, expires}
queues: dict[str, asyncio.Queue] = {} # gateway → Queue
semaphores: dict[str, asyncio.Semaphore] = {} # gateway → Semaphore
workers: list[asyncio.Task] = []
last_request: dict[str, dict[str, float]] = {} # api_key → {gateway → timestamp} (cooldown)
rate_windows: dict[str, list[float]] = {} # api_key → [timestamps] (rate limit)
failed_ips: dict[str, dict] = {} # ip → {count, blocked_until}
dedup_cache: dict[str, dict] = {} # sha256(card+gw) → {result, expires}
gateway_errors: dict[str, int] = {} # gateway → consecutive error count
avg_times: dict[str, float] = {} # gateway → avg response time
START_TIME = 0.0
# ── Request Models ──
class CheckRequest(BaseModel):
card: str
class CreateKeyRequest(BaseModel):
owner: str
gateways: list[str] | str = "*"
request_limit: int | None = None
expires_days: int | None = None
rate_per_minute: int = 10
class UpdateKeyRequest(BaseModel):
owner: str | None = None
gateways: list[str] | str | None = None
request_limit: int | None = None
rate_per_minute: int | None = None
is_active: bool | None = None
is_paused: bool | None = None
expires_days: int | None = None
# ── Helpers ──
CARD_RE = re.compile(r"^\d{13,19}\|\d{1,2}\|\d{2,4}\|\d{3,4}$")
def validate_card(card: str) -> bool:
return bool(CARD_RE.match(card.strip()))
def get_client_ip(request: Request) -> str:
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"
def check_ip_block(ip: str):
info = failed_ips.get(ip)
if not info:
return
if info.get("blocked_until") and time.time() < info["blocked_until"]:
raise HTTPException(403, {"detail": "IP blocked"})
if info.get("blocked_until") and time.time() >= info["blocked_until"]:
del failed_ips[ip]
def record_failed_auth(ip: str):
info = failed_ips.setdefault(ip, {"count": 0, "blocked_until": None})
info["count"] += 1
if info["count"] >= 10:
info["blocked_until"] = time.time() + 900 # 15 min
log.warning(f"IP blocked: {ip} (10 failed auth attempts)")
async def get_api_key(request: Request) -> dict:
ip = get_client_ip(request)
check_ip_block(ip)
key = request.headers.get("x-api-key", "")
if not key:
record_failed_auth(ip)
raise HTTPException(401, {"detail": "Missing API key"})
record = await db.get_key(key)
if not record:
record_failed_auth(ip)
raise HTTPException(401, {"detail": "Invalid API key"})
if not record["is_active"]:
raise HTTPException(401, {"detail": "API key deactivated"})
if record["is_paused"]:
raise HTTPException(401, {"detail": "API key paused"})
if record["expires_at"]:
exp = datetime.fromisoformat(record["expires_at"])
if datetime.now(timezone.utc) > exp:
raise HTTPException(401, {"detail": "API key expired"})
if record["request_limit"] is not None and record["requests_used"] >= record["request_limit"]:
raise HTTPException(429, {"detail": "Request limit exceeded",
"used": record["requests_used"], "limit": record["request_limit"]})
return record
def check_rate_limit(api_key: str, rate_per_minute: int):
now = time.time()
window = rate_windows.setdefault(api_key, [])
cutoff = now - 60
rate_windows[api_key] = [t for t in window if t > cutoff]
if len(rate_windows[api_key]) >= rate_per_minute:
oldest = min(rate_windows[api_key])
retry_after = 60 - (now - oldest)
raise HTTPException(429, {"detail": "Rate limit exceeded", "retry_after": round(retry_after, 1)})
rate_windows[api_key].append(now)
def check_cooldown(api_key: str, gateway: str, cooldown: float):
now = time.time()
last = last_request.get(api_key, {}).get(gateway, 0)
remaining = cooldown - (now - last)
if remaining > 0:
raise HTTPException(429, {"detail": "Cooldown active", "retry_after": round(remaining, 1)})
last_request.setdefault(api_key, {})[gateway] = now
def check_dedup(card: str, gateway: str) -> dict | None:
h = hashlib.sha256(f"{card}:{gateway}".encode()).hexdigest()
cached = dedup_cache.get(h)
if cached and time.time() < cached["expires"]:
return cached["result"]
if cached:
del dedup_cache[h]
return None
def store_dedup(card: str, gateway: str, result: dict):
h = hashlib.sha256(f"{card}:{gateway}".encode()).hexdigest()
dedup_cache[h] = {"result": result, "expires": time.time() + 60}
def parse_gateway_result(raw: str, gateway: str, elapsed: float) -> dict:
"""Convert gateway string result to API dict format."""
raw_lower = raw.lower()
if raw_lower.startswith("approved") or raw_lower.startswith("charged"):
status = "approved"
elif raw_lower.startswith("declined"):
status = "declined"
elif raw_lower.startswith("rate limited"):
status = "error"
else:
status = "error"
return {"status": status, "gateway": gateway, "message": raw, "time": round(elapsed, 2)}
def get_queue_position(task_id: str, gateway: str) -> int:
"""Estimate position by counting queued tasks for this gateway created before this task."""
task = tasks.get(task_id)
if not task:
return 0
pos = 0
for tid, t in tasks.items():
if t.get("gateway") == gateway and t.get("status") == "queued" and t.get("created", 0) <= task.get("created", 0):
pos += 1
return pos
# ── Workers ──
async def gateway_worker(gateway_name: str):
queue = queues[gateway_name]
sem = semaphores[gateway_name]
gw = GATEWAYS[gateway_name]
while True:
task_id, card, api_key, ip = await queue.get()
tasks[task_id]["status"] = "processing"
async with sem:
start = time.time()
try:
raw = await asyncio.wait_for(gw["fn"](card), timeout=gw["timeout"])
elapsed = time.time() - start
result = parse_gateway_result(raw, gateway_name, elapsed)
# update avg time
prev = avg_times.get(gateway_name, elapsed)
avg_times[gateway_name] = (prev + elapsed) / 2
# reset error counter on success
if result["status"] != "error":
gateway_errors[gateway_name] = 0
else:
gateway_errors[gateway_name] = gateway_errors.get(gateway_name, 0) + 1
except asyncio.TimeoutError:
elapsed = time.time() - start
result = {"status": "error", "gateway": gateway_name, "message": "Timeout", "time": round(elapsed, 2)}
gateway_errors[gateway_name] = gateway_errors.get(gateway_name, 0) + 1
except Exception as e:
elapsed = time.time() - start
result = {"status": "error", "gateway": gateway_name, "message": f"Error: {e}", "time": round(elapsed, 2)}
gateway_errors[gateway_name] = gateway_errors.get(gateway_name, 0) + 1
log.exception(f"Gateway {gateway_name} error for task {task_id}")
# auto-disable on 5 consecutive errors
if gateway_errors.get(gateway_name, 0) >= 5:
GATEWAYS[gateway_name]["status"] = "maintenance"
log.warning(f"Gateway {gateway_name} auto-disabled after 5 consecutive errors")
# store result
tasks[task_id]["status"] = "completed"
tasks[task_id]["result"] = result
tasks[task_id]["expires"] = time.time() + 300 # 5 min TTL
# store dedup
store_dedup(card, gateway_name, result)
# log to DB
try:
await db.log_request(api_key, gateway_name, card, result["status"], elapsed, ip)
await db.increment_usage(api_key)
except Exception:
log.exception("Failed to log request to DB")
queue.task_done()
async def cleanup_loop():
"""Evict expired tasks, dedup cache entries, and rate windows every 60s."""
while True:
await asyncio.sleep(60)
now = time.time()
# expired tasks
expired = [tid for tid, t in tasks.items() if t.get("expires") and now > t["expires"]]
for tid in expired:
del tasks[tid]
# expired dedup
expired_dedup = [h for h, v in dedup_cache.items() if now > v["expires"]]
for h in expired_dedup:
del dedup_cache[h]
# expired IP blocks
expired_ips = [ip for ip, v in failed_ips.items()
if v.get("blocked_until") and now > v["blocked_until"]]
for ip in expired_ips:
del failed_ips[ip]
# ── Lifecycle ──
@asynccontextmanager
async def lifespan(app: FastAPI):
global START_TIME
START_TIME = time.time()
# init DB
await db.init_db()
log.info("Database initialized")
# init queues + semaphores + workers
for name, gw in GATEWAYS.items():
queues[name] = asyncio.Queue(maxsize=gw["max_queue"])
semaphores[name] = asyncio.Semaphore(gw["max_concurrent"])
worker = asyncio.create_task(gateway_worker(name), name=f"worker-{name}")
workers.append(worker)
gateway_errors[name] = 0
log.info(f"Gateway '{name}' ready (max_concurrent={gw['max_concurrent']}, cooldown={gw['cooldown']}s)")
# init gateway hooks
for name, gw in GATEWAYS.items():
if gw["init"]:
try:
await gw["init"]()
log.info(f"Gateway '{name}' init hook completed")
except Exception:
log.exception(f"Gateway '{name}' init hook failed")
# start cleanup loop
cleanup = asyncio.create_task(cleanup_loop(), name="cleanup")
workers.append(cleanup)
log.info(f"API server started with {len(GATEWAYS)} gateways")
yield
# shutdown
log.info("Shutting down...")
for w in workers:
w.cancel()
for name, gw in GATEWAYS.items():
if gw["shutdown"]:
try:
await gw["shutdown"]()
log.info(f"Gateway '{name}' shutdown hook completed")
except Exception:
log.exception(f"Gateway '{name}' shutdown hook failed")
log.info("Shutdown complete")
app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
@app.exception_handler(HTTPException)
async def custom_http_exception(request: Request, exc: HTTPException):
"""Return detail directly — avoids FastAPI double-wrapping dicts."""
content = exc.detail if isinstance(exc.detail, dict) else {"detail": exc.detail}
return JSONResponse(status_code=exc.status_code, content=content)
# ── Public ──
@app.get("/health")
async def health():
return {"status": "ok", "uptime": round(time.time() - START_TIME)}
# ── Client Endpoints ──
@app.post("/api/check/{gateway}")
async def submit_check(gateway: str, body: CheckRequest, request: Request):
key_record = await get_api_key(request)
api_key = key_record["api_key"]
# gateway exists?
if gateway not in GATEWAYS:
raise HTTPException(404, {"detail": f"Gateway not found: {gateway}"})
gw = GATEWAYS[gateway]
# gateway online?
if gw["status"] != "online":
raise HTTPException(503, {"detail": f"Gateway offline: {gateway}"})
# gateway access?
allowed = key_record["allowed_gateways"]
if allowed != "*" and gateway not in allowed:
raise HTTPException(403, {"detail": f"Access denied for gateway: {gateway}"})
# validate card
card = body.card.strip()
if not validate_card(card):
raise HTTPException(400, {"detail": "Invalid card format"})
# rate limit
check_rate_limit(api_key, key_record["rate_per_minute"])
# cooldown
check_cooldown(api_key, gateway, gw["cooldown"])
# dedup
cached = check_dedup(card, gateway)
if cached:
task_id = str(uuid.uuid4())[:8]
tasks[task_id] = {
"api_key": api_key, "status": "completed", "result": cached,
"gateway": gateway, "created": time.time(), "expires": time.time() + 300
}
remaining = None
if key_record["request_limit"]:
remaining = key_record["request_limit"] - key_record["requests_used"]
return {"task_id": task_id, "status": "completed", "result": cached,
"remaining": remaining, "cached": True}
# queue full?
if queues[gateway].full():
raise HTTPException(503, {"detail": "Gateway overloaded", "queue_depth": gw["max_queue"]})
# create task
task_id = str(uuid.uuid4())[:8]
ip = get_client_ip(request)
tasks[task_id] = {
"api_key": api_key, "status": "queued", "result": None,
"gateway": gateway, "created": time.time(), "expires": time.time() + 300
}
await queues[gateway].put((task_id, card, api_key, ip))
position = queues[gateway].qsize()
est_wait = round(position * avg_times.get(gateway, 10), 1)
return {"task_id": task_id, "status": "queued", "gateway": gateway,
"position": position, "estimated_wait": est_wait}
@app.get("/api/result/{task_id}")
async def poll_result(task_id: str, request: Request):
key_record = await get_api_key(request)
task = tasks.get(task_id)
if not task or task["api_key"] != key_record["api_key"]:
raise HTTPException(404, {"detail": "Task not found"})
now = time.time()
expires_in = round(task["expires"] - now) if task.get("expires") else None
if task["status"] == "queued":
position = get_queue_position(task_id, task["gateway"])
return {"task_id": task_id, "status": "queued", "position": position, "expires_in": expires_in}
if task["status"] == "processing":
return {"task_id": task_id, "status": "processing", "expires_in": expires_in}
# completed
remaining = None
rec = await db.get_key(key_record["api_key"])
if rec and rec["request_limit"]:
remaining = rec["request_limit"] - rec["requests_used"]
return {"task_id": task_id, "status": "completed", "result": task["result"],
"remaining": remaining}
@app.get("/api/usage")
async def get_usage(request: Request):
key_record = await get_api_key(request)
return {
"owner": key_record["owner"],
"requests_used": key_record["requests_used"],
"requests_limit": key_record["request_limit"],
"expires_at": key_record["expires_at"],
"gateways": key_record["allowed_gateways"],
"rate_per_minute": key_record["rate_per_minute"],
}
@app.get("/api/gateways")
async def list_gateways(request: Request):
key_record = await get_api_key(request)
allowed = key_record["allowed_gateways"]
result = {}
for name, gw in GATEWAYS.items():
if allowed == "*" or name in allowed:
result[name] = {"status": gw["status"], "type": gw["type"]}
return {"gateways": result}
@app.get("/api/cooldown")
async def get_cooldown(request: Request):
key_record = await get_api_key(request)
api_key = key_record["api_key"]
allowed = key_record["allowed_gateways"]
now = time.time()
result = {}
for name, gw in GATEWAYS.items():
if allowed != "*" and name not in allowed:
continue
last = last_request.get(api_key, {}).get(name, 0)
remaining = gw["cooldown"] - (now - last)
if remaining > 0:
result[name] = {"ready": False, "retry_after": round(remaining, 1)}
else:
result[name] = {"ready": True, "retry_after": 0}
return result
# ── Admin Endpoints ──
def require_admin(request: Request):
key = request.headers.get("x-api-key", "")
if not MASTER_KEY or key != MASTER_KEY:
raise HTTPException(401, {"detail": "Admin access denied"})
@app.post("/admin/keys")
async def admin_create_key(body: CreateKeyRequest, request: Request):
require_admin(request)
result = await db.create_key(
owner=body.owner, gateways=body.gateways, request_limit=body.request_limit,
expires_days=body.expires_days, rate_per_minute=body.rate_per_minute
)
await db.log_admin("create_key", result["api_key"], {"owner": body.owner}, get_client_ip(request))
log.info(f"Key created for '{body.owner}': {result['api_key'][:15]}...")
return result
@app.get("/admin/keys")
async def admin_list_keys(request: Request):
require_admin(request)
return await db.list_keys()
@app.get("/admin/keys/{key}")
async def admin_get_key(key: str, request: Request):
require_admin(request)
record = await db.get_key(key)
if not record:
raise HTTPException(404, {"detail": "Key not found"})
return record
@app.patch("/admin/keys/{key}")
async def admin_update_key(key: str, body: UpdateKeyRequest, request: Request):
require_admin(request)
record = await db.get_key(key)
if not record:
raise HTTPException(404, {"detail": "Key not found"})
updates = {}
changes = {}
if body.owner is not None:
updates["owner"] = body.owner
changes["owner"] = [record["owner"], body.owner]
if body.gateways is not None:
updates["allowed_gateways"] = body.gateways
changes["gateways"] = [record["allowed_gateways"], body.gateways]
if body.request_limit is not None:
updates["request_limit"] = body.request_limit
changes["request_limit"] = [record["request_limit"], body.request_limit]
if body.rate_per_minute is not None:
updates["rate_per_minute"] = body.rate_per_minute
changes["rate_per_minute"] = [record["rate_per_minute"], body.rate_per_minute]
if body.is_active is not None:
updates["is_active"] = body.is_active
changes["is_active"] = [record["is_active"], body.is_active]
if body.is_paused is not None:
updates["is_paused"] = body.is_paused
changes["is_paused"] = [record["is_paused"], body.is_paused]
if body.expires_days is not None:
from datetime import timedelta
new_exp = (datetime.now(timezone.utc) + timedelta(days=body.expires_days)).isoformat()
updates["expires_at"] = new_exp
changes["expires_at"] = [record["expires_at"], new_exp]
if not updates:
raise HTTPException(400, {"detail": "No fields to update"})
await db.update_key(key, **updates)
action = "pause_key" if body.is_paused is True else "unpause_key" if body.is_paused is False else "update_key"
await db.log_admin(action, key, changes, get_client_ip(request))
log.info(f"Key updated: {key[:15]}... — {list(changes.keys())}")
return {"detail": "Key updated", "changes": changes}
@app.delete("/admin/keys/{key}")
async def admin_delete_key(key: str, request: Request):
require_admin(request)
deleted = await db.delete_key(key)
if not deleted:
raise HTTPException(404, {"detail": "Key not found"})
await db.log_admin("revoke_key", key, None, get_client_ip(request))
log.info(f"Key revoked: {key[:15]}...")
return {"detail": "Key revoked"}
@app.get("/admin/stats")
async def admin_stats(request: Request):
require_admin(request)
stats = await db.get_stats_24h()
return {
"uptime": round(time.time() - START_TIME),
"active_tasks": sum(1 for t in tasks.values() if t["status"] in ("queued", "processing")),
"queue_depth": {name: q.qsize() for name, q in queues.items()},
"gateway_status": {name: gw["status"] for name, gw in GATEWAYS.items()},
**stats,
}
# ── Run ──
if __name__ == "__main__":
import uvicorn
host = os.getenv("API_HOST", "0.0.0.0")
port = int(os.getenv("API_PORT", "8000"))
uvicorn.run("api:app", host=host, port=port, log_level="info")