# Source code for unsprawl.sentinel

"""Real-time OODA Loop Agent using Gemini Live API.

This module implements the SentinelAgent that maintains a bidirectional
WebSocket session for real-time multimodal interaction (voice + vision).

The agent follows the OODA loop pattern:
- Observe: Receive sensor data via WebSocket
- Orient: Process context and understand situation
- Decide: Generate actions via function calling
- Act: Execute local functions and return results

Example
-------
>>> from unsprawl.sentinel import SentinelAgent
>>>
>>> agent = SentinelAgent()
>>> await agent.run_session("Monitor traffic in sector A-7")
"""

from __future__ import annotations

import asyncio
import logging
import os
from collections.abc import Callable
from typing import Any

import google.genai as genai
from google.genai import types

logger = logging.getLogger(__name__)


# Default traffic optimization function (can be overridden)
async def _default_optimize_traffic(sector_id: str) -> dict[str, Any]:
    """Placeholder traffic optimization function.

    In production, this would call the Warp-based DifferentiableTraffic
    optimizer.

    Parameters
    ----------
    sector_id : str
        The sector identifier to optimize (e.g., 'A-7').

    Returns
    -------
    dict[str, Any]
        A stub optimization result with fixed signal adjustments.
    """
    # Lazy %-style args defer string formatting until the record is emitted.
    logger.info("Optimizing traffic for sector: %s", sector_id)
    return {
        "sector_id": sector_id,
        "status": "optimized",
        "signal_adjustments": {"north_south": 45, "east_west": 30},
        "estimated_improvement": 0.15,
    }
class SentinelAgent:
    """Real-time multimodal OODA loop agent over WebSockets.

    Maintains a persistent connection to Gemini Live API for bidirectional
    streaming of audio, video, and text.

    Parameters
    ----------
    model : str
        The Gemini Live model to use.
    api_key : str | None
        Google API key. If None, reads from GOOGLE_API_KEY environment
        variable.
    response_modalities : list[str]
        Output modalities. Options: ["TEXT"], ["AUDIO"], ["TEXT", "AUDIO"].

    Notes
    -----
    - Uses `client.aio.live.connect()` for WebSocket streaming
    - Supports tool calls for local function execution
    - OODA = Observe-Orient-Decide-Act (military decision loop)
    """

    def __init__(
        self,
        model: str = "gemini-live-2.5-flash-preview",
        api_key: str | None = None,
        response_modalities: list[str] | None = None,
    ) -> None:
        # Resolve credentials up front: explicit argument wins over env var.
        resolved_key = api_key if api_key else os.environ.get("GOOGLE_API_KEY")
        if not resolved_key:
            raise ValueError(
                "API key required. Set GOOGLE_API_KEY environment variable "
                "or pass api_key parameter."
            )

        self.model = model
        self.response_modalities = (
            response_modalities if response_modalities else ["TEXT"]
        )
        self.client = genai.Client(api_key=resolved_key)

        # Registry of local functions the model may invoke via tool calls.
        self._functions: dict[str, Callable[..., Any]] = {
            "optimize_traffic": _default_optimize_traffic,
        }
[docs] def register_function(self, name: str, func: Callable[..., Any]) -> None: """Register a function that can be called by the model. Parameters ---------- name : str Function name as it will appear to the model. func : Callable The function to execute when called. """ self._functions[name] = func logger.info(f"Registered function: {name}")
[docs] def _build_tools(self) -> list[types.Tool]: """Build the tools configuration for the Live session.""" function_declarations = [ types.FunctionDeclaration( name="optimize_traffic", description=( "Optimize traffic signal timings for a specific sector. " "Call this when congestion is detected or traffic flow " "needs improvement." ), parameters=types.Schema( type="OBJECT", properties={ "sector_id": types.Schema( type="STRING", description="The sector identifier (e.g., 'A-7', 'B-12')", ), }, required=["sector_id"], ), ), ] return [types.Tool(function_declarations=function_declarations)]
[docs] async def _handle_tool_call( self, session: Any, tool_call: types.LiveServerToolCall, ) -> None: """Handle a tool call from the model and send the response back. Parameters ---------- session : AsyncSession The active Live session. tool_call : LiveServerToolCall The tool call request from the model. """ function_responses = [] for func_call in tool_call.function_calls: func_name = func_call.name func_args = func_call.args or {} call_id = func_call.id logger.info(f"Tool call: {func_name}({func_args})") if func_name in self._functions: try: func = self._functions[func_name] # Handle both sync and async functions if asyncio.iscoroutinefunction(func): result = await func(**func_args) else: result = func(**func_args) function_responses.append( types.FunctionResponse( name=func_name, response={"result": result}, id=call_id, ) ) logger.info(f"Function {func_name} completed: {result}") except Exception as e: logger.error(f"Function {func_name} failed: {e}") function_responses.append( types.FunctionResponse( name=func_name, response={"error": str(e)}, id=call_id, ) ) else: logger.warning(f"Unknown function: {func_name}") function_responses.append( types.FunctionResponse( name=func_name, response={"error": f"Unknown function: {func_name}"}, id=call_id, ) ) # Send tool responses back await session.send_tool_response(function_responses=function_responses)
[docs] async def run_session( self, initial_prompt: str, on_message: Callable[[str], None] | None = None, max_turns: int = 10, ) -> list[str]: """Run a Live API session with the given initial prompt. Parameters ---------- initial_prompt : str The initial message to send to establish context. on_message : Callable[[str], None] | None Optional callback for each received text message. max_turns : int Maximum number of conversation turns before closing. Returns ------- list[str] List of text responses received during the session. """ config = types.LiveConnectConfig( response_modalities=self.response_modalities, tools=self._build_tools(), system_instruction=( "You are an autonomous city infrastructure monitor. " "Observe sensor data, detect anomalies, and take action " "to optimize traffic flow. Use the optimize_traffic function " "when you detect congestion or inefficiencies." ), ) responses: list[str] = [] turn_count = 0 logger.info(f"Connecting to Live API with model: {self.model}") async with self.client.aio.live.connect( model=self.model, config=config, ) as session: # Send initial prompt await session.send_client_content( turns=types.Content( role="user", parts=[types.Part(text=initial_prompt)], ), turn_complete=True, ) logger.info(f"Sent initial prompt: {initial_prompt[:50]}...") # Process responses async for response in session.receive(): # Handle text responses if response.text: responses.append(response.text) if on_message: on_message(response.text) logger.debug(f"Received: {response.text[:100]}...") # Handle tool calls if response.tool_call: logger.info("Received tool call request") await self._handle_tool_call(session, response.tool_call) # Check for turn completion if response.server_content and response.server_content.turn_complete: turn_count += 1 logger.info(f"Turn {turn_count} complete") if turn_count >= max_turns: logger.info("Max turns reached, closing session") break logger.info(f"Session ended with {len(responses)} responses") return responses
[docs] async def send_observation( self, session: Any, observation: str, ) -> None: """Send an observation to an active session. Parameters ---------- session : AsyncSession The active Live session. observation : str The observation text to send. """ await session.send_client_content( turns=types.Content( role="user", parts=[types.Part(text=observation)], ), turn_complete=True, ) logger.info(f"Sent observation: {observation[:50]}...")