Source code for osprey.data_management.manager

"""
Data Source Manager

Unified data source management system that replaces both the registry and integration service
with a cleaner approach supporting core and application-specific data sources.
"""

import asyncio
import time
import warnings
from dataclasses import dataclass, field
from typing import Any

from osprey.utils.logger import get_logger

from .providers import DataSourceContext, DataSourceProvider
from .request import DataSourceRequest

logger = get_logger("data_manager")


@dataclass
class DataRetrievalResult:
    """Result of data retrieval from multiple sources.

    Attributes:
        context_data: Retrieved contexts, keyed by the name of the source
            that produced them.
        successful_sources: Names of the sources that returned a context.
        failed_sources: Names of the sources that did not return a context.
        total_sources_attempted: Number of sources that were queried.
        retrieval_time_sec: Duration of the retrieval in seconds, or None
            when no timing was recorded.
    """

    context_data: dict[str, DataSourceContext] = field(default_factory=dict)
    successful_sources: list[str] = field(default_factory=list)
    failed_sources: list[str] = field(default_factory=list)
    total_sources_attempted: int = 0
    retrieval_time_sec: float | None = None

    @property
    def has_data(self) -> bool:
        """Check if any data was successfully retrieved."""
        return bool(self.context_data)

    @property
    def success_rate(self) -> float:
        """Calculate the success rate of data retrieval.

        Returns:
            Fraction of attempted sources that succeeded, or 0.0 when no
            source was attempted (avoids a division by zero).
        """
        if self.total_sources_attempted == 0:
            return 0.0
        return len(self.successful_sources) / self.total_sources_attempted

    @property
    def retrieval_time_ms(self) -> float | None:
        """
        Deprecated: Use retrieval_time_sec instead.

        Returns retrieval time in milliseconds for backwards compatibility,
        or None when retrieval_time_sec is None. Emits a DeprecationWarning
        on every access.
        """
        warnings.warn(
            "retrieval_time_ms is deprecated, use retrieval_time_sec instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.retrieval_time_sec * 1000 if self.retrieval_time_sec is not None else None

    def get_summary(self) -> dict[str, Any]:
        """Get a summary of the retrieval results.

        Returns:
            Dictionary with the number of sources attempted / successful /
            failed, the success rate, the distinct context types retrieved
            and the retrieval time in seconds.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # summary is stable across runs (a set would yield an arbitrary order).
        context_types = list(
            dict.fromkeys(ctx.context_type for ctx in self.context_data.values())
        )
        return {
            "sources_attempted": self.total_sources_attempted,
            "sources_successful": len(self.successful_sources),
            "sources_failed": len(self.failed_sources),
            "success_rate": self.success_rate,
            "context_types_retrieved": context_types,
            "retrieval_time_sec": self.retrieval_time_sec,
        }
class DataSourceManager:
    """
    Unified data source management system.

    Replaces both DataSourceRegistry and DataSourceIntegrationService with a cleaner
    architecture that supports core and application-specific data sources.
    """

    def __init__(self):
        # Insertion-ordered mapping of provider name -> provider. The insertion
        # order is the order in which providers are queried.
        self._providers: dict[str, DataSourceProvider] = {}
        self._initialized = False

    def register_provider(self, provider: DataSourceProvider) -> None:
        """
        Register a data source provider.

        Providers are queried in registration order (framework providers first,
        then application providers). Registering a provider under a name that is
        already present replaces the earlier provider.

        Args:
            provider: Provider to register; it is stored under ``provider.name``
        """
        self._providers[provider.name] = provider
        logger.info(f"Registered data source: {provider.name}")

    def get_responding_providers(self, request: DataSourceRequest) -> list[DataSourceProvider]:
        """
        Get all providers that should respond to the current request in registration order.

        Args:
            request: Data source request with requester information

        Returns:
            List of providers that should respond in registration order
            (framework first, then applications)
        """
        return [p for p in self._providers.values() if p.should_respond(request)]

    async def retrieve_all_context(
        self, request: DataSourceRequest, timeout_seconds: float = 30.0
    ) -> DataRetrievalResult:
        """
        Retrieve context from all responding data sources.

        All responding providers are queried concurrently. If the timeout
        expires, sources that already finished keep their result; sources that
        are still running are cancelled and reported as failed.

        Args:
            request: Data source request with requester information
            timeout_seconds: Maximum time to wait for all data sources

        Returns:
            DataRetrievalResult containing all successfully retrieved data
        """
        # Monotonic clock: the measured duration must not be affected by
        # wall-clock adjustments while the retrieval is running.
        start_time = time.monotonic()

        # Get responding providers in registration order
        providers = self.get_responding_providers(request)
        if not providers:
            logger.info("No data sources available for current context")
            return DataRetrievalResult(total_sources_attempted=0)

        logger.info(f"Retrieving context from {len(providers)} data sources")

        # One task per provider, in the same order as `providers`
        tasks = [
            asyncio.create_task(
                self._retrieve_from_provider(provider, request),
                name=f"retrieve_{provider.name}",
            )
            for provider in providers
        ]

        # Wait for all tasks with timeout
        try:
            results = await asyncio.wait_for(
                asyncio.gather(*tasks, return_exceptions=True),
                timeout=timeout_seconds,
            )
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is only an alias of the builtin TimeoutError
            # from Python 3.11 on; naming it explicitly also covers 3.10.
            logger.warning(f"Data source retrieval timed out after {timeout_seconds}s")
            # Cancel remaining tasks
            for task in tasks:
                if not task.done():
                    task.cancel()
            # Keep whatever finished before the deadline; the rest count as failed
            results = [self._settled_outcome(task) for task in tasks]

        # Process results
        context_data: dict[str, DataSourceContext] = {}
        successful_sources: list[str] = []
        failed_sources: list[str] = []
        empty_sources: list[str] = []

        for provider, result in zip(providers, results, strict=True):
            provider_name = provider.name
            # BaseException rather than Exception: gather(return_exceptions=True)
            # can hand back asyncio.CancelledError, which is not an Exception and
            # must never be stored as context data.
            if isinstance(result, BaseException):
                logger.warning(f"Data retrieval failed for {provider_name}: {result}")
                failed_sources.append(provider_name)
            elif result is None:
                failed_sources.append(provider_name)
            else:
                context_data[provider_name] = result
                successful_sources.append(provider_name)
                if self._has_meaningful_content(result):
                    logger.debug(f"Successfully retrieved data from {provider_name}")
                else:
                    logger.debug(
                        f"Retrieved empty result from {provider_name} (no data available)"
                    )
                    empty_sources.append(provider_name)

        retrieval_time_sec = time.monotonic() - start_time

        retrieval_result = DataRetrievalResult(
            context_data=context_data,
            successful_sources=successful_sources,
            failed_sources=failed_sources,
            total_sources_attempted=len(providers),
            retrieval_time_sec=retrieval_time_sec,
        )

        # Log human-readable summary with better clarity.
        # Every empty source is also a successful source, so a subtraction suffices.
        sources_with_data = len(successful_sources) - len(empty_sources)

        if failed_sources or empty_sources:
            details = []
            if sources_with_data > 0:
                details.append(f"{sources_with_data} with data")
            if empty_sources:
                details.append(f"{len(empty_sources)} empty")
            if failed_sources:
                details.append(f"{len(failed_sources)} failed")
            logger.info(
                f"Data sources checked: {len(providers)} ({', '.join(details)}) in {retrieval_time_sec:.2f}s"
            )
        else:
            logger.info(
                f"Retrieved data from {sources_with_data} source{'s' if sources_with_data != 1 else ''} in {retrieval_time_sec:.2f}s"
            )

        return retrieval_result

    @staticmethod
    def _settled_outcome(task: asyncio.Task) -> Any:
        """Return a finished task's result or exception; None if unfinished or cancelled."""
        if not task.done() or task.cancelled():
            return None
        error = task.exception()
        return error if error is not None else task.result()

    @staticmethod
    def _has_meaningful_content(result: DataSourceContext) -> bool:
        """Tell whether a retrieved context carries data (used for logging only)."""
        try:
            # Truthy data counts as content (works for container-like types)
            if result.data:
                return True
            # Otherwise fall back to metadata hints like entry_count
            return result.metadata.get("entry_count", 0) > 0
        except Exception:
            # If we can't determine, assume it has content
            return True

    def get_provider(self, provider_name: str) -> DataSourceProvider | None:
        """
        Get a specific data source provider by name.

        Args:
            provider_name: Name of the data source provider to retrieve

        Returns:
            DataSourceProvider if found, None otherwise
        """
        return self._providers.get(provider_name)

    async def retrieve_from_provider(
        self, provider_name: str, request: DataSourceRequest
    ) -> DataSourceContext | None:
        """
        Retrieve data from a specific provider by name.

        Args:
            provider_name: Name of the data source provider
            request: Data source request

        Returns:
            DataSourceContext if successful, None if the provider is not found,
            declines to respond, or the retrieval failed
        """
        provider = self.get_provider(provider_name)
        if not provider:
            logger.warning(f"Data source provider '{provider_name}' not found")
            return None

        if not provider.should_respond(request):
            logger.debug(f"Provider '{provider_name}' chose not to respond to request")
            return None

        return await self._retrieve_from_provider(provider, request)

    async def _retrieve_from_provider(
        self, provider: DataSourceProvider, request: DataSourceRequest
    ) -> DataSourceContext | None:
        """
        Retrieve data from a single provider with error handling.

        Args:
            provider: The data source provider to retrieve from
            request: Data source request

        Returns:
            DataSourceContext if successful, None if failed
        """
        try:
            logger.debug(f"Retrieving data from {provider.name}")
            return await provider.retrieve_data(request)
        except Exception as e:
            # Only Exception is swallowed; cancellation still propagates.
            logger.warning(f"Failed to retrieve data from {provider.name}: {e}")
            return None
# Process-wide manager instance, created on first use by get_data_source_manager()
_data_source_manager: DataSourceManager | None = None


def get_data_source_manager() -> DataSourceManager:
    """
    Return the global data source manager, creating it on the first call.

    On creation, every data source reported by the registry is registered with
    the new manager, so providers are queried in the order the registry returns
    them. A failure while loading from the registry is logged as a warning and
    the manager is still returned.
    """
    global _data_source_manager

    # Fast path: the manager already exists
    if _data_source_manager is not None:
        return _data_source_manager

    # Publish the instance before loading so the global is set even if the
    # registry import or lookup below fails.
    manager = DataSourceManager()
    _data_source_manager = manager

    try:
        from osprey.registry import get_registry

        # Pull every data source known to the registry
        sources = get_registry().get_all_data_sources()
        for source in sources:
            manager.register_provider(source)

        logger.info(f"Loaded {len(sources)} data sources from registry")
    except Exception as exc:
        logger.warning(f"Failed to load data sources from registry: {exc}")

    return _data_source_manager