407 lines
15 KiB
Python
407 lines
15 KiB
Python
# This file is part of Pythagoras.
|
|
#
|
|
# Pythagoras is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# Pythagoras is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with Pythagoras. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
from fastapi import FastAPI, File, Request, WebSocket, WebSocketDisconnect, BackgroundTasks, HTTPException
|
|
from fastapi.responses import JSONResponse, FileResponse
|
|
import logging
|
|
import uvicorn
|
|
from typing import Dict, List, Any
|
|
from dataclasses import dataclass
|
|
import json
|
|
import httpx
|
|
import asyncio
|
|
import struct
|
|
import os
|
|
from pathlib import Path
|
|
from asyncio import Lock
|
|
|
|
# Upstream PHP endpoint that returns the currently selected message.
PEEHAITCHPEA_ENDPOINT: str = "http://peehaitchpea.libre-liberec.cz/api.php?cmd=getselectedmessage"
# TCP port used when the service is launched directly as a script.
PYTHAGORAS_PORT: int = 8000

# Root logging configuration: INFO and above, one timestamped line per record.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger shared by the whole service.
logger = logging.getLogger("pythagoras")
|
|
|
|
# FastAPI application, plus the mutable runtime state shared by the HTTP
# handlers, the WebSocket clients and the background poller.
app = FastAPI(
    title="Pythagoras",
    description="A proxy service handling HTTP and WebSocket connections",
)

# Periodically poll the selected-message endpoint (the poller tests `is True`).
app.state.auto_polling = False
# Seconds the background poller sleeps between iterations.
app.state.polling_rate = 5
# Also translate submitted subtitle sentences to Czech and broadcast them.
app.state.enable_libretranslate = True
# Last screen sent to clients ("idle", "main" or "video"); replayed on connect.
app.state.client_state = "idle"
# Last selected message; replayed to clients when they connect.
app.state.latest_message = None
|
|
|
|
# Root directory for files served under /media/...; created at import time
# (including any missing parents) so the endpoint always has a directory to
# resolve requests against.
MEDIA_DIR = Path("./media")
MEDIA_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Wrapper around a single WebSocket that serialises the calls made through it.
class WSConnection:
    """A WebSocket paired with an asyncio lock.

    Every call made through this wrapper (accept / send_bytes / send_text)
    runs while holding ``lock``, so concurrent coroutines using the wrapper
    never interleave operations on the same socket.
    """

    inner: WebSocket  # the wrapped socket
    lock: Lock  # held for the duration of every wrapped call

    def __init__(self, sock: WebSocket):
        self.inner, self.lock = sock, Lock()

    async def _locked(self, operation, *args):
        """Await ``operation(*args)`` while holding the connection lock."""
        async with self.lock:
            await operation(*args)

    async def accept(self):
        """Accept the WebSocket handshake."""
        await self._locked(self.inner.accept)

    async def send_bytes(self, data: bytes):
        """Send ``data`` as a single binary frame."""
        await self._locked(self.inner.send_bytes, data)

    async def send_text(self, data: str):
        """Send ``data`` as a single text frame."""
        await self._locked(self.inner.send_text, data)
|
|
|
|
class ConnectionManager:
    """Registry of connected WebSocket clients with helpers to message them.

    Structured messages are JSON objects sent as binary frames built by
    ``json_to_binary``: a 4-byte big-endian unsigned length prefix followed
    by the UTF-8 encoded JSON text.
    """

    def __init__(self):
        # Clients in connection order; mutated by connect() / disconnect().
        self.active_connections: List[WSConnection] = []

    async def connect(self, websocket: WSConnection):
        """Accept ``websocket``, register it and bring it up to date.

        Only the new client receives the current screen state
        (``app.state.client_state``) and the latest selected message
        (``app.state.latest_message``), so it shows what the other clients show.
        """
        await websocket.accept()
        self.active_connections.append(websocket)
        await setscreen_single(websocket, app.state.client_state)
        await self.singlecast_binary(
            {"type": "selectedmessage", "message": app.state.latest_message}, websocket
        )
        logger.info(f"WebSocket client connected. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WSConnection):
        """Unregister ``websocket``; a connection that is not registered is ignored."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        logger.info(f"WebSocket client disconnected. Total connections: {len(self.active_connections)}")

    def json_to_binary(self, json_str: str):
        """Frame ``json_str`` as a 4-byte big-endian length prefix plus its UTF-8 bytes.

        Raises:
            struct.error: if the encoded payload length does not fit in an
                unsigned 32-bit integer.
        """
        json_bytes = json_str.encode('utf-8')
        # '!I' = network byte order uint32, which always packs to exactly 4 bytes.
        return struct.pack('!I', len(json_bytes)) + json_bytes

    async def broadcast(self, message: str):
        """Send ``message`` as a text frame to every connected client.

        A failed send is logged and does not stop delivery to the others.
        """
        # Iterate over a snapshot: each await yields to the event loop, where
        # another task may connect/disconnect a client and mutate the list;
        # iterating the live list would then skip clients.
        for connection in tuple(self.active_connections):
            try:
                await connection.send_text(message)
            except Exception as e:
                logger.error(f"Failed to send text message: {str(e)}")

    async def broadcast_binary(self, data_dict: dict):
        """
        Broadcasts a message to all connections as binary data.

        ``data_dict`` is JSON-encoded once and framed with ``json_to_binary``.
        Nothing is done when no client is connected. A failed send is logged
        and does not stop delivery to the others.
        """
        if not self.active_connections:
            return

        message_bytes = self._encode(data_dict)
        # Snapshot for the same reason as in broadcast().
        for connection in tuple(self.active_connections):
            await self._send_binary(connection, message_bytes)

    async def singlecast_binary(self, data_dict: dict, ws: WSConnection):
        """Send ``data_dict`` as a framed binary JSON message to ``ws`` only.

        A failed send is logged, not raised.
        """
        await self._send_binary(ws, self._encode(data_dict))

    def _encode(self, data_dict: dict) -> bytes:
        """Serialise ``data_dict`` to JSON and frame it with ``json_to_binary``."""
        return self.json_to_binary(json.dumps(data_dict))

    async def _send_binary(self, connection: WSConnection, message_bytes: bytes):
        """Send one already-framed message to ``connection``, logging any failure."""
        try:
            await connection.send_bytes(message_bytes)
            logger.debug(f"Sent binary message ({len(message_bytes)} bytes) to a client")
        except Exception as e:
            logger.error(f"Failed to send binary message: {str(e)}")


manager = ConnectionManager()
|
|
|
|
@app.get("/presentation/")
async def presentation_index(_: Request):
    """Return the presentation front-end page, ``static/index.html``."""
    index_page = "static/index.html"
    return FileResponse(path=index_page, media_type="text/html", status_code=200)
|
|
|
|
@app.get("/presentation/script.js")
async def presentation_script(_: Request):
    """Return the presentation client script, ``static/script.js``."""
    script_file = "static/script.js"
    return FileResponse(path=script_file, media_type="text/javascript", status_code=200)
|
|
|
|
@app.get("/presentation/style.css")
async def presentation_style(_: Request):
    """Return the presentation stylesheet, ``static/style.css``."""
    stylesheet = "static/style.css"
    return FileResponse(path=stylesheet, media_type="text/css", status_code=200)
|
|
|
|
@app.get("/presentation/files/{file_path:path}")
async def presentation_file(file_path: str):
    """Serve a file from the ``static`` directory (relative to the working directory).

    Args:
        file_path: Path of the requested file relative to ``static``; may
            contain slashes (``:path`` converter).

    Raises:
        HTTPException: 403 if ``file_path`` resolves outside ``static``
            (e.g. via ``..`` segments or an absolute path), 404 if it does not
            name an existing file.
    """
    static_root = Path("static").resolve()
    target = (static_root / file_path).resolve()

    # ``file_path`` comes straight from the URL and may contain ``..``; confine
    # the resolved location to the static directory by whole path components.
    if not target.is_relative_to(static_root):
        logger.warning(f"Attempted directory traversal: {file_path}")
        raise HTTPException(status_code=403, detail="Access denied")

    # Without this check a missing file surfaces as a server error while the
    # response is being sent, instead of a clean 404.
    if not target.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(status_code=200, path=target)
|
|
|
|
# Endpoints
|
|
@app.post("/control")
async def control_endpoint(request: Request):
    """Endpoint for control data.

    The request body must be a JSON object with a ``command`` key. Supported
    commands and their extra keys:

    * ``getselectedmessage`` -- fetch the selected message now and broadcast it.
    * ``setautopolling`` (``state``) -- store ``state`` in ``app.state.auto_polling``.
    * ``setlibretranslate`` (``state``) -- store ``state`` in
      ``app.state.enable_libretranslate``.
    * ``autopollingrate`` (``rate``) -- seconds between automatic polls; must be
      a finite positive number.
    * ``playvideo`` (``filename``, optional ``subtitles`` and
      ``secondsfromstart``) -- start video playback on all clients.
    * ``seekvideo`` (``timestamp``) -- seek the playing video.
    * ``setscreen_main`` / ``setscreen_video`` / ``setscreen_idle`` -- switch the
      client screen and remember it for clients that connect later.

    Unknown commands, and known commands missing their required key, are
    logged and otherwise ignored.

    Returns:
        A 200 JSONResponse once the command has been handled, or a 400
        JSONResponse if the body is not a JSON object with ``command``, a value
        is invalid, or handling the command raises.
    """
    # Screen-switch commands -> (screen name sent to clients, log description).
    screen_commands = {
        "setscreen_main": ("main", "main view"),
        "setscreen_video": ("video", "video playback"),
        "setscreen_idle": ("idle", "idle"),
    }

    try:
        data = await request.json()
        logger.info(f"Received control data: {data}")
        command = data['command']

        if command == "getselectedmessage":
            # An empty message is treated the same as "no message".
            message = await fetch_selected_message() or None
            await process_selected_message(message)
            logger.info(f"Received new selected message initiated from control: {message}")

        elif command == "setautopolling" and 'state' in data:
            new_state = data['state']
            app.state.auto_polling = new_state
            logger.info(f"Polling command issued, changing auto-polling to {new_state}")

        elif command == "setlibretranslate" and 'state' in data:
            new_state = data['state']
            app.state.enable_libretranslate = new_state
            logger.info(f"LibreTranslate command issued, changing state to {new_state}")

        elif command == "autopollingrate" and 'rate' in data:
            new_rate = data['rate']
            # The background poller passes this value to asyncio.sleep(): a
            # non-number would kill the poller, and zero/negative would make it
            # poll the upstream endpoint in a tight loop.
            if not isinstance(new_rate, (int, float)) or not 0 < new_rate < float("inf"):
                raise ValueError(f"Invalid polling rate: {new_rate!r}")
            app.state.polling_rate = new_rate
            logger.info(f"Auto-polling rate change requested: {new_rate} seconds")

        elif command == "playvideo" and 'filename' in data:
            filename = data['filename']
            subtitles = data.get('subtitles')
            seconds_from_start = data.get('secondsfromstart', 0)
            await playvideo(filename, subtitles, seconds_from_start)
            logger.info(f"Video playback requested: {filename} with subtitles {subtitles} starting {seconds_from_start} seconds from start.")

        elif command == "seekvideo" and 'timestamp' in data:
            timestamp = data['timestamp']
            await seekvideo(timestamp)
            logger.info(f"Seeking of the currently playing video requested to {timestamp} seconds")

        elif command in screen_commands:
            screen, description = screen_commands[command]
            # Record the state before broadcasting so a client that connects
            # while the broadcast is in flight is greeted with the new screen.
            app.state.client_state = screen
            await setscreen(screen)
            logger.info(f"Setting the client screen to {description}.")

        else:
            logger.warning(f"Ignoring unknown or incomplete control command: {command}")

        return JSONResponse(
            status_code=200,
            content={"status": "success", "message": "Control data received"}
        )
    except Exception as e:
        logger.error(f"Error processing control data: {str(e)}")
        return JSONResponse(
            status_code=400,
            content={"status": "error", "message": "Failed to process request."}
        )
|
|
|
|
@app.post("/subtitles/update_current")
async def subtitles_update_current_endpoint(request: Request):
    """Endpoint for subtitle data - updating the current sentence as it comes.

    The plain-text body is handed to ``process_subtitles`` with the
    ``update_current`` sub-type, and its JSONResponse is returned unchanged.
    """
    response = await process_subtitles(request, sub_type="update_current")
    return response
|
|
|
|
@app.post("/subtitles/submit_sentence")
async def subtitles_submit_sentence_endpoint(request: Request):
    """Endpoint for subtitle data - submitting the final version of a sentence.

    The plain-text body is handed to ``process_subtitles`` with the
    ``submit_sentence`` sub-type, and its JSONResponse is returned unchanged.
    """
    response = await process_subtitles(request, sub_type="submit_sentence")
    return response
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time communication.

    Registers the client with ``manager`` (which sends it the current screen
    and latest message), then logs every text frame it sends until the
    connection ends. The client is unregistered however the loop exits.
    """
    conn = WSConnection(websocket)
    await manager.connect(conn)
    try:
        while True:
            data = await websocket.receive_text()
            logger.info(f"Received message from WebSocket: {data}")
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"WebSocket error: {str(e)}")
    finally:
        # Also runs when the handler task is cancelled (asyncio.CancelledError
        # is not an Exception subclass), so the connection is never left in
        # manager.active_connections to receive future broadcasts.
        manager.disconnect(conn)
|
|
|
|
@app.get("/media/{file_path:path}")
async def get_media(file_path: str):
    """
    Serve media files from the media directory.

    Args:
        file_path: Path of the requested file relative to ``MEDIA_DIR``; may
            contain slashes (``:path`` converter).

    Returns:
        A FileResponse for the file, with its base name passed as ``filename``.

    Raises:
        HTTPException: 403 if the path resolves outside ``MEDIA_DIR``,
            404 if it does not name an existing file.
    """
    media_root = MEDIA_DIR.resolve()
    absolute_path = (MEDIA_DIR / file_path).resolve()

    # Security check: compare whole path components. A plain string-prefix
    # test would also accept siblings such as "<root>/media_private" for a
    # "<root>/media" directory.
    if not absolute_path.is_relative_to(media_root):
        logger.warning(f"Attempted directory traversal: {file_path}")
        raise HTTPException(status_code=403, detail="Access denied")

    if not absolute_path.is_file():
        logger.warning(f"File not found: {absolute_path}")
        raise HTTPException(status_code=404, detail="File not found")

    logger.info(f"Serving media file: {file_path}")

    return FileResponse(path=absolute_path, filename=absolute_path.name)
|
|
|
|
|
|
|
|
# Functions
|
|
async def process_subtitles(request: Request, sub_type: str):
    """Relay a plain-text subtitle request body to the WebSocket clients.

    The body is decoded as UTF-8 and, if any client is connected, broadcast
    as a ``subtitle_en_<sub_type>`` message. When
    ``app.state.enable_libretranslate`` is truthy and ``sub_type`` is
    ``"submit_sentence"``, the text is also translated to Czech and broadcast
    as ``subtitle_cs_<sub_type>``.

    Returns:
        A 200 JSONResponse on success, or a 400 JSONResponse (after logging
        the error) if anything along the way raises.
    """
    try:
        raw_body = await request.body()
        english = raw_body.decode("utf-8")
        logger.info(f"Received subtitle text: {english}, request type: {sub_type}")

        if manager.active_connections:
            await manager.broadcast_binary({"type": f"subtitle_en_{sub_type}", "text": english})

        if app.state.enable_libretranslate and sub_type == "submit_sentence":
            czech = await translate_to_cs_libre(english)
            await manager.broadcast_binary({"type": f"subtitle_cs_{sub_type}", "text": czech})
    except Exception as e:
        logger.error(f"Error processing {sub_type} subtitle data: {str(e)}")
        return JSONResponse(
            status_code=400,
            content={"status": "error", "message": "Failed to process request."},
        )
    else:
        return JSONResponse(
            status_code=200,
            content={"status": "success", "message": "Subtitle text received"},
        )
|
|
|
|
async def translate_to_cs_libre(
    text: str,
    *,
    url: str = "http://localhost:5000/translate",
    timeout: float = 10.0,
):
    """
    Translates the provided text from English to Czech using LibreTranslate.

    Args:
        text: English source text. Falsy input (``""`` or ``None``) is returned
            unchanged without contacting the server.
        url: LibreTranslate ``/translate`` endpoint; defaults to a local instance.
        timeout: Request timeout in seconds.

    Returns:
        The translated text, or the original ``text`` if the server answers
        with a non-200 status, returns a payload without a string
        ``translatedText``, or the request fails.
    """
    if not text:
        return text

    payload = {
        "q": text,
        "source": "en",
        "target": "cs",
        "format": "text",
    }

    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
            response = await client.post(url, json=payload)

        if response.status_code != 200:
            logger.error(f"Translation API error: {response.status_code}, {response.text}")
            return text

        result = response.json()
        # Guard against a body that is not a JSON object, or whose translation
        # is missing or null -- broadcasting None to clients would blank the subtitle.
        translated_text = result.get("translatedText") if isinstance(result, dict) else None
        if not isinstance(translated_text, str):
            logger.error(f"Unexpected translation response: {result!r}")
            return text

        logger.info("Successfully translated text to Czech")
        return translated_text

    except Exception as e:
        logger.error(f"Translation error: {str(e)}")
        return text
|
|
|
|
async def setscreen_single(ws: WSConnection, screen: str):
    """Tell only the client ``ws`` to switch to ``screen``."""
    command = dict(type="setscreen", screen=screen)
    return await manager.singlecast_binary(command, ws)
|
|
|
|
async def setscreen(screen: str):
    """Broadcast a ``setscreen`` command switching every client to ``screen``."""
    command = dict(type="setscreen", screen=screen)
    return await manager.broadcast_binary(command)
|
|
|
|
async def playvideo(filename: str, subtitles: str | None, seconds_from_start: int):
    """Broadcast a ``playvideo`` command to every client.

    ``subtitles`` may be ``None``; ``seconds_from_start`` is the playback
    start offset in seconds.
    """
    command = dict(
        type="playvideo",
        filename=filename,
        subtitles=subtitles,
        seconds_from_start=seconds_from_start,
    )
    return await manager.broadcast_binary(command)
|
|
|
|
async def seekvideo(timestamp: int):
    """Broadcast a ``seekvideo`` command asking clients to jump to ``timestamp``."""
    command = dict(type="seekvideo", timestamp=timestamp)
    return await manager.broadcast_binary(command)
|
|
|
|
|
|
|
|
|
|
async def fetch_selected_message():
    """
    Fetches the currently selected message from ``PEEHAITCHPEA_ENDPOINT``.

    HTTP basic-auth credentials are read on every call from the
    ``PEEHAITCHPEA_USERNAME`` and ``PEEHAITCHPEA_PASSWORD`` environment
    variables, falling back to the previously hard-coded values when unset.

    Returns:
        The response body with surrounding whitespace stripped, or ``None``
        if the body is blank, the status is not 200, or the request raises.
        Callers therefore cannot tell a failed fetch from "no message".
    """
    auth = httpx.BasicAuth(
        username=os.environ.get("PEEHAITCHPEA_USERNAME", "stallman"),
        password=os.environ.get("PEEHAITCHPEA_PASSWORD", "gnu"),
    )
    try:
        async with httpx.AsyncClient(auth=auth) as client:
            response = await client.get(PEEHAITCHPEA_ENDPOINT)

        if response.status_code != 200:
            logger.warning(f"Failed to fetch message. Status code: {response.status_code}")
            return None

        message = response.text.strip()
        if not message:
            return None

        logger.info(f"Received selected message: {message}")
        return message
    except Exception as e:
        logger.error(f"Error fetching selected message: {str(e)}")
        return None
|
|
|
|
async def process_selected_message(message: str | None):
    """
    Store ``message`` as the latest selected message and push it to clients.

    The value is kept in ``app.state.latest_message`` and, when at least one
    client is connected, broadcast as a ``selectedmessage`` message.
    ``None`` is stored and broadcast as-is.
    """
    logger.info(f"Processing message: {message}")
    app.state.latest_message = message
    if not manager.active_connections:
        return
    await manager.broadcast_binary({"type": "selectedmessage", "message": message})
|
|
|
|
async def periodic_message_check(*, default_rate: float = 5.0):
    """Periodically checks for new messages.

    Runs forever. Each iteration fetches and processes the selected message
    when ``app.state.auto_polling is True``, then sleeps
    ``app.state.polling_rate`` seconds.

    Args:
        default_rate: Sleep interval (seconds) used when
            ``app.state.polling_rate`` is not a finite positive number.
    """
    while True:
        if app.state.auto_polling is True:
            try:
                logger.info("Automatically polling message...")
                message = await fetch_selected_message()
                await process_selected_message(message)
            except Exception:
                # This coroutine is run as a background task that nobody
                # awaits: an escaping exception would silently stop polling
                # for the rest of the process lifetime.
                logger.exception("Automatic message polling failed")

        rate = app.state.polling_rate
        # A non-number would raise in asyncio.sleep() (killing the task), and
        # zero/negative would turn this into a busy loop against the upstream.
        if not isinstance(rate, (int, float)) or not 0 < rate < float("inf"):
            logger.warning(f"Invalid polling rate {rate!r}, using {default_rate} seconds")
            rate = default_rate
        await asyncio.sleep(rate)
|
|
|
|
# Startup tasks setup
@app.on_event("startup")
async def startup_event():
    """Start background tasks when the application starts.

    The polling task is kept on ``app.state.polling_task``: the event loop
    holds only weak references to tasks, so a task whose result is not
    stored anywhere may be garbage-collected before it finishes.
    """
    app.state.polling_task = asyncio.create_task(periodic_message_check())
|
|
|
|
# Script entry point: run the app under uvicorn on all interfaces.
if __name__ == "__main__":
    # The import string assumes this file is importable as the module `main`;
    # reload=True makes uvicorn restart the server when source files change.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=PYTHAGORAS_PORT,
        reload=True,
    )
|