"""Real-time OODA Loop Agent using Gemini Live API.
This module implements the SentinelAgent that maintains a bidirectional
WebSocket session for real-time multimodal interaction (voice + vision).
The agent follows the OODA loop pattern:
- Observe: Receive sensor data via WebSocket
- Orient: Process context and understand situation
- Decide: Generate actions via function calling
- Act: Execute local functions and return results
Example
-------
>>> from unsprawl.sentinel import SentinelAgent
>>>
>>> agent = SentinelAgent()
>>> await agent.run_session("Monitor traffic in sector A-7")
"""
from __future__ import annotations
import asyncio
import logging
import os
from collections.abc import Callable
from typing import Any
import google.genai as genai
from google.genai import types
logger = logging.getLogger(__name__)
# Default traffic optimization function (can be overridden)
# Default traffic optimization function (can be overridden)
async def _default_optimize_traffic(sector_id: str) -> dict[str, Any]:
    """Placeholder traffic optimization function.

    In production, this would call the Warp-based DifferentiableTraffic
    optimizer.

    Parameters
    ----------
    sector_id : str
        Identifier of the traffic sector to optimize.

    Returns
    -------
    dict[str, Any]
        Static optimization summary: the echoed sector id, a status
        string, per-direction signal timing adjustments (seconds), and
        an estimated improvement ratio.
    """
    # Lazy %-style args avoid formatting the message when INFO is disabled.
    logger.info("Optimizing traffic for sector: %s", sector_id)
    return {
        "sector_id": sector_id,
        "status": "optimized",
        "signal_adjustments": {"north_south": 45, "east_west": 30},
        "estimated_improvement": 0.15,
    }
[docs]
class SentinelAgent:
"""Real-time multimodal OODA loop agent over WebSockets.
Maintains a persistent connection to Gemini Live API for bidirectional
streaming of audio, video, and text.
Parameters
----------
model : str
The Gemini Live model to use.
api_key : str | None
Google API key. If None, reads from GOOGLE_API_KEY environment variable.
response_modalities : list[str]
Output modalities. Options: ["TEXT"], ["AUDIO"], ["TEXT", "AUDIO"].
Notes
-----
- Uses `client.aio.live.connect()` for WebSocket streaming
- Supports tool calls for local function execution
- OODA = Observe-Orient-Decide-Act (military decision loop)
"""
def __init__(
self,
model: str = "gemini-live-2.5-flash-preview",
api_key: str | None = None,
response_modalities: list[str] | None = None,
) -> None:
self.model = model
self.response_modalities = response_modalities or ["TEXT"]
# Initialize the client
api_key = api_key or os.environ.get("GOOGLE_API_KEY")
if not api_key:
raise ValueError(
"API key required. Set GOOGLE_API_KEY environment variable "
"or pass api_key parameter."
)
self.client = genai.Client(api_key=api_key)
# Function registry for tool calls
self._functions: dict[str, Callable[..., Any]] = {
"optimize_traffic": _default_optimize_traffic,
}
[docs]
def register_function(self, name: str, func: Callable[..., Any]) -> None:
"""Register a function that can be called by the model.
Parameters
----------
name : str
Function name as it will appear to the model.
func : Callable
The function to execute when called.
"""
self._functions[name] = func
logger.info(f"Registered function: {name}")
[docs]
async def run_session(
self,
initial_prompt: str,
on_message: Callable[[str], None] | None = None,
max_turns: int = 10,
) -> list[str]:
"""Run a Live API session with the given initial prompt.
Parameters
----------
initial_prompt : str
The initial message to send to establish context.
on_message : Callable[[str], None] | None
Optional callback for each received text message.
max_turns : int
Maximum number of conversation turns before closing.
Returns
-------
list[str]
List of text responses received during the session.
"""
config = types.LiveConnectConfig(
response_modalities=self.response_modalities,
tools=self._build_tools(),
system_instruction=(
"You are an autonomous city infrastructure monitor. "
"Observe sensor data, detect anomalies, and take action "
"to optimize traffic flow. Use the optimize_traffic function "
"when you detect congestion or inefficiencies."
),
)
responses: list[str] = []
turn_count = 0
logger.info(f"Connecting to Live API with model: {self.model}")
async with self.client.aio.live.connect(
model=self.model,
config=config,
) as session:
# Send initial prompt
await session.send_client_content(
turns=types.Content(
role="user",
parts=[types.Part(text=initial_prompt)],
),
turn_complete=True,
)
logger.info(f"Sent initial prompt: {initial_prompt[:50]}...")
# Process responses
async for response in session.receive():
# Handle text responses
if response.text:
responses.append(response.text)
if on_message:
on_message(response.text)
logger.debug(f"Received: {response.text[:100]}...")
# Handle tool calls
if response.tool_call:
logger.info("Received tool call request")
await self._handle_tool_call(session, response.tool_call)
# Check for turn completion
if response.server_content and response.server_content.turn_complete:
turn_count += 1
logger.info(f"Turn {turn_count} complete")
if turn_count >= max_turns:
logger.info("Max turns reached, closing session")
break
logger.info(f"Session ended with {len(responses)} responses")
return responses
[docs]
async def send_observation(
self,
session: Any,
observation: str,
) -> None:
"""Send an observation to an active session.
Parameters
----------
session : AsyncSession
The active Live session.
observation : str
The observation text to send.
"""
await session.send_client_content(
turns=types.Content(
role="user",
parts=[types.Part(text=observation)],
),
turn_complete=True,
)
logger.info(f"Sent observation: {observation[:50]}...")