"""
Data Source Manager
Unified data source management system that replaces both the registry and integration service
with a cleaner approach supporting core and application-specific data sources.
"""
import asyncio
import time
import warnings
from dataclasses import dataclass, field
from typing import Any
from osprey.utils.logger import get_logger
from .providers import DataSourceContext, DataSourceProvider
from .request import DataSourceRequest
# Module-level logger used by DataSourceManager and get_data_source_manager().
logger = get_logger("data_manager")
@dataclass
class DataRetrievalResult:
    """
    Result of data retrieval from multiple sources.

    Attributes:
        context_data: Retrieved context objects keyed by provider name.
        successful_sources: Names of providers that returned a result.
        failed_sources: Names of providers that did not return a result.
        total_sources_attempted: Number of providers that were queried.
        retrieval_time_sec: Elapsed retrieval time in seconds, or None when
            no timing was recorded.
    """

    context_data: dict[str, DataSourceContext] = field(default_factory=dict)
    successful_sources: list[str] = field(default_factory=list)
    failed_sources: list[str] = field(default_factory=list)
    total_sources_attempted: int = 0
    retrieval_time_sec: float | None = None

    @property
    def has_data(self) -> bool:
        """Check if any data was successfully retrieved."""
        return bool(self.context_data)

    @property
    def success_rate(self) -> float:
        """
        Calculate the success rate of data retrieval.

        Returns:
            Fraction of attempted sources that succeeded, or 0.0 when no
            source was attempted (avoids a division by zero).
        """
        if self.total_sources_attempted == 0:
            return 0.0
        return len(self.successful_sources) / self.total_sources_attempted

    @property
    def retrieval_time_ms(self) -> float | None:
        """
        Deprecated: Use retrieval_time_sec instead.

        Returns retrieval time in milliseconds for backwards compatibility,
        or None when retrieval_time_sec is None. Emits a DeprecationWarning
        on every access.
        """
        warnings.warn(
            "retrieval_time_ms is deprecated, use retrieval_time_sec instead",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the code reading the property
        )
        if self.retrieval_time_sec is None:
            return None
        return self.retrieval_time_sec * 1000

    def get_summary(self) -> dict[str, Any]:
        """
        Get a summary of the retrieval results.

        Returns:
            Dictionary with attempt/success/failure counts, the success rate,
            the distinct context types retrieved (in first-seen order) and
            the retrieval time in seconds.
        """
        # dict.fromkeys de-duplicates like a set but keeps insertion order, so
        # the summary is identical from run to run instead of depending on
        # hash ordering.
        context_types = list(
            dict.fromkeys(ctx.context_type for ctx in self.context_data.values())
        )
        return {
            "sources_attempted": self.total_sources_attempted,
            "sources_successful": len(self.successful_sources),
            "sources_failed": len(self.failed_sources),
            "success_rate": self.success_rate,
            "context_types_retrieved": context_types,
            "retrieval_time_sec": self.retrieval_time_sec,
        }
class DataSourceManager:
    """
    Unified data source management system.

    Replaces both DataSourceRegistry and DataSourceIntegrationService with a
    cleaner architecture that supports core and application-specific data sources.

    Providers are stored in a dict keyed by provider name, so they are
    queried in the order in which their names were first registered.
    """

    def __init__(self):
        """Create a manager with no registered providers."""
        self._providers: dict[str, DataSourceProvider] = {}
        # Set here but never read by any method of this class.
        self._initialized = False

    def register_provider(self, provider: DataSourceProvider) -> None:
        """
        Register a data source provider.

        Providers are queried in registration order (framework providers first,
        then application providers). Registering a provider whose name is
        already present replaces the earlier provider under that name.

        Args:
            provider: Provider to register; stored under ``provider.name``
        """
        self._providers[provider.name] = provider
        logger.info(f"Registered data source: {provider.name}")

    def get_responding_providers(self, request: DataSourceRequest) -> list[DataSourceProvider]:
        """
        Get all providers that should respond to the current request in registration order.

        Args:
            request: Data source request with requester information

        Returns:
            List of providers that should respond in registration order (framework first, then applications)
        """
        return [p for p in self._providers.values() if p.should_respond(request)]

    async def retrieve_all_context(
        self, request: DataSourceRequest, timeout_seconds: float = 30.0
    ) -> DataRetrievalResult:
        """
        Retrieve context from all responding data sources.

        All responding providers are queried concurrently. If the overall
        timeout expires, providers that had already finished keep their
        results; only the ones still running are reported as failed.

        Args:
            request: Data source request with requester information
            timeout_seconds: Maximum time to wait for all data sources

        Returns:
            DataRetrievalResult containing all successfully retrieved data
        """
        # Monotonic clock: the elapsed time must not be skewed by wall-clock
        # adjustments that happen while the retrieval is in flight.
        start_time = time.monotonic()

        # Get responding providers in registration order
        providers = self.get_responding_providers(request)
        if not providers:
            logger.info("No data sources available for current context")
            return DataRetrievalResult(total_sources_attempted=0)

        logger.info(f"Retrieving context from {len(providers)} data sources")

        # One named task per provider so they all run concurrently.
        tasks = [
            (
                provider.name,
                asyncio.create_task(
                    self._retrieve_from_provider(provider, request),
                    name=f"retrieve_{provider.name}",
                ),
            )
            for provider in providers
        ]

        # Wait for all tasks with timeout
        try:
            results = await asyncio.wait_for(
                asyncio.gather(*(task for _, task in tasks), return_exceptions=True),
                timeout=timeout_seconds,
            )
        except asyncio.TimeoutError:
            # asyncio.TimeoutError only became an alias of the builtin
            # TimeoutError in Python 3.11; naming it explicitly keeps the
            # handler effective on older interpreters as well.
            logger.warning(f"Data source retrieval timed out after {timeout_seconds}s")
            # Keep whatever finished in time; unfinished tasks become None.
            results = [self._salvage_task_result(task) for _, task in tasks]

        # Process results
        context_data: dict[str, DataSourceContext] = {}
        successful_sources: list[str] = []
        failed_sources: list[str] = []
        empty_sources: list[str] = []

        for (provider_name, _), result in zip(tasks, results, strict=True):
            # BaseException (not Exception) so that a CancelledError collected
            # by gather() is counted as a failure rather than stored as data.
            if isinstance(result, BaseException):
                logger.warning(f"Data retrieval failed for {provider_name}: {result}")
                failed_sources.append(provider_name)
                continue
            if result is None:
                failed_sources.append(provider_name)
                continue

            context_data[provider_name] = result
            successful_sources.append(provider_name)
            if self._has_meaningful_content(result):
                logger.debug(f"Successfully retrieved data from {provider_name}")
            else:
                logger.debug(f"Retrieved empty result from {provider_name} (no data available)")
                empty_sources.append(provider_name)

        retrieval_time_sec = time.monotonic() - start_time
        retrieval_result = DataRetrievalResult(
            context_data=context_data,
            successful_sources=successful_sources,
            failed_sources=failed_sources,
            total_sources_attempted=len(providers),
            retrieval_time_sec=retrieval_time_sec,
        )

        # Log human-readable summary with better clarity.
        # Every empty source is also a successful one, so a subtraction suffices.
        sources_with_data = len(successful_sources) - len(empty_sources)
        if failed_sources or empty_sources:
            details = []
            if sources_with_data > 0:
                details.append(f"{sources_with_data} with data")
            if empty_sources:
                details.append(f"{len(empty_sources)} empty")
            if failed_sources:
                details.append(f"{len(failed_sources)} failed")
            logger.info(
                f"Data sources checked: {len(providers)} ({', '.join(details)}) in {retrieval_time_sec:.2f}s"
            )
        else:
            logger.info(
                f"Retrieved data from {sources_with_data} source{'s' if sources_with_data != 1 else ''} in {retrieval_time_sec:.2f}s"
            )

        return retrieval_result

    @staticmethod
    def _salvage_task_result(task: asyncio.Task) -> Any:
        """Return the outcome of a finished task; cancel an unfinished one and return None."""
        if not task.done():
            task.cancel()
            return None
        if task.cancelled():
            return None
        error = task.exception()
        return error if error is not None else task.result()

    @staticmethod
    def _has_meaningful_content(context: DataSourceContext) -> bool:
        """Report whether a retrieved context carries data or a positive entry count."""
        try:
            # Truthiness check (works for UserMemories and similar types)
            if context.data:
                return True
            # Also check metadata hints like entry_count
            return context.metadata.get("entry_count", 0) > 0
        except Exception:
            # If we can't determine, assume it has content
            return True

    def get_provider(self, provider_name: str) -> DataSourceProvider | None:
        """
        Get a specific data source provider by name.

        Args:
            provider_name: Name of the data source provider to retrieve

        Returns:
            DataSourceProvider if found, None otherwise
        """
        return self._providers.get(provider_name)

    async def retrieve_from_provider(
        self, provider_name: str, request: DataSourceRequest
    ) -> DataSourceContext | None:
        """
        Retrieve data from a specific provider by name.

        Args:
            provider_name: Name of the data source provider
            request: Data source request

        Returns:
            DataSourceContext if successful, None if the provider is not found,
            declines to respond to the request, or retrieval failed
        """
        provider = self.get_provider(provider_name)
        if not provider:
            logger.warning(f"Data source provider '{provider_name}' not found")
            return None
        if not provider.should_respond(request):
            logger.debug(f"Provider '{provider_name}' chose not to respond to request")
            return None
        return await self._retrieve_from_provider(provider, request)

    async def _retrieve_from_provider(
        self, provider: DataSourceProvider, request: DataSourceRequest
    ) -> DataSourceContext | None:
        """
        Retrieve data from a single provider with error handling.

        Args:
            provider: The data source provider to retrieve from
            request: Data source request

        Returns:
            DataSourceContext if successful, None if failed
        """
        try:
            logger.debug(f"Retrieving data from {provider.name}")
            return await provider.retrieve_data(request)
        except Exception as e:
            # Best effort: one failing provider is logged and reported as None.
            logger.warning(f"Failed to retrieve data from {provider.name}: {e}")
            return None
# Global manager instance: None until get_data_source_manager() creates it on
# first call, after which the same instance is returned on every call.
_data_source_manager: DataSourceManager | None = None
def get_data_source_manager() -> DataSourceManager:
    """
    Return the process-wide data source manager, building it on first use.

    The first call creates the manager and registers every data source
    reported by the registry system, in the order the registry returns them
    (framework first, then applications). A failure while loading from the
    registry is logged as a warning and the manager is still returned with
    whatever was registered up to that point. Later calls return the same
    instance without touching the registry again.
    """
    global _data_source_manager

    # Fast path: the singleton already exists.
    if _data_source_manager is not None:
        return _data_source_manager

    # Publish the new instance before loading so the global is set even if
    # the registry import or lookup below fails.
    manager = _data_source_manager = DataSourceManager()

    try:
        # Imported here rather than at module top, as in the original code.
        from osprey.registry import get_registry

        sources = get_registry().get_all_data_sources()
        for source in sources:
            manager.register_provider(source)
        logger.info(f"Loaded {len(sources)} data sources from registry")
    except Exception as e:
        logger.warning(f"Failed to load data sources from registry: {e}")

    return manager