Skip to content

Commit cd571ea

Browse files
committed
feat: Add TTL-based cache expiration and fix deprecations
Changes: - Add TTL support to in-memory cache with configurable expiration times - Prices: 1 hour TTL - Financial metrics: 24 hours TTL (less volatile) - Company news: 30 minutes TTL (time-sensitive) - Add cache cleanup for expired entries - Add clear() method to reset all caches - Fix deprecated @app.on_event('startup') -> lifespan context manager - Fix deprecated declarative_base import location - Reduce black line-length from 420 to 120 for readability - Add comprehensive test suite for cache module (24 tests) Tests cover: - CacheEntry creation and expiration - Get/set operations for all cache types - TTL enforcement per data type - Data deduplication on merge - Expired entry cleanup - Multiple ticker isolation
1 parent 8935283 commit cd571ea

File tree

5 files changed

+512
-67
lines changed

5 files changed

+512
-67
lines changed

app/backend/database/connection.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from sqlalchemy import create_engine
2-
from sqlalchemy.ext.declarative import declarative_base
3-
from sqlalchemy.orm import sessionmaker
2+
from sqlalchemy.orm import sessionmaker, declarative_base
43
import os
54
from pathlib import Path
65

app/backend/main.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from contextlib import asynccontextmanager
12
from fastapi import FastAPI
23
from fastapi.middleware.cors import CORSMiddleware
34
import logging
@@ -12,26 +13,11 @@
1213
logging.basicConfig(level=logging.INFO)
1314
logger = logging.getLogger(__name__)
1415

15-
app = FastAPI(title="AI Hedge Fund API", description="Backend API for AI Hedge Fund", version="0.1.0")
1616

17-
# Initialize database tables (this is safe to run multiple times)
18-
Base.metadata.create_all(bind=engine)
19-
20-
# Configure CORS
21-
app.add_middleware(
22-
CORSMiddleware,
23-
allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"], # Frontend URLs
24-
allow_credentials=True,
25-
allow_methods=["*"],
26-
allow_headers=["*"],
27-
)
28-
29-
# Include all routes
30-
app.include_router(api_router)
31-
32-
@app.on_event("startup")
33-
async def startup_event():
34-
"""Startup event to check Ollama availability."""
17+
@asynccontextmanager
18+
async def lifespan(app: FastAPI):
19+
"""Lifespan context manager for startup and shutdown events."""
20+
# Startup
3521
try:
3622
logger.info("Checking Ollama availability...")
3723
status = await ollama_service.check_ollama_status()
@@ -53,3 +39,31 @@ async def startup_event():
5339
except Exception as e:
5440
logger.warning(f"Could not check Ollama status: {e}")
5541
logger.info("ℹ Ollama integration is available if you install it later")
42+
43+
yield
44+
45+
# Shutdown (cleanup if needed)
46+
logger.info("Shutting down AI Hedge Fund API...")
47+
48+
49+
app = FastAPI(
50+
title="AI Hedge Fund API",
51+
description="Backend API for AI Hedge Fund",
52+
version="0.1.0",
53+
lifespan=lifespan
54+
)
55+
56+
# Initialize database tables (this is safe to run multiple times)
57+
Base.metadata.create_all(bind=engine)
58+
59+
# Configure CORS
60+
app.add_middleware(
61+
CORSMiddleware,
62+
allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"], # Frontend URLs
63+
allow_credentials=True,
64+
allow_methods=["*"],
65+
allow_headers=["*"],
66+
)
67+
68+
# Include all routes
69+
app.include_router(api_router)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ requires = ["poetry-core"]
4848
build-backend = "poetry.core.masonry.api"
4949

5050
[tool.black]
51-
line-length = 420
51+
line-length = 120
5252
target-version = ['py311']
5353
include = '\.pyi?$'
5454

src/data/cache.py

Lines changed: 111 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,34 @@
1+
import time
2+
from typing import Any
3+
4+
5+
class CacheEntry:
6+
"""A cache entry with TTL support."""
7+
8+
def __init__(self, data: list[dict[str, Any]], ttl_seconds: int = 3600):
9+
self.data = data
10+
self.created_at = time.time()
11+
self.ttl_seconds = ttl_seconds
12+
13+
def is_expired(self) -> bool:
14+
"""Check if the cache entry has expired."""
15+
return time.time() - self.created_at > self.ttl_seconds
16+
17+
118
class Cache:
2-
"""In-memory cache for API responses."""
19+
"""In-memory cache for API responses with TTL support."""
20+
21+
# Default TTL: 1 hour for most data, 24 hours for less volatile data
22+
DEFAULT_TTL = 3600 # 1 hour
23+
METRICS_TTL = 86400 # 24 hours (metrics don't change frequently)
24+
NEWS_TTL = 1800 # 30 minutes (news is more time-sensitive)
325

426
def __init__(self):
5-
self._prices_cache: dict[str, list[dict[str, any]]] = {}
6-
self._financial_metrics_cache: dict[str, list[dict[str, any]]] = {}
7-
self._line_items_cache: dict[str, list[dict[str, any]]] = {}
8-
self._insider_trades_cache: dict[str, list[dict[str, any]]] = {}
9-
self._company_news_cache: dict[str, list[dict[str, any]]] = {}
27+
self._prices_cache: dict[str, CacheEntry] = {}
28+
self._financial_metrics_cache: dict[str, CacheEntry] = {}
29+
self._line_items_cache: dict[str, CacheEntry] = {}
30+
self._insider_trades_cache: dict[str, CacheEntry] = {}
31+
self._company_news_cache: dict[str, CacheEntry] = {}
1032

1133
def _merge_data(self, existing: list[dict] | None, new_data: list[dict], key_field: str) -> list[dict]:
1234
"""Merge existing and new data, avoiding duplicates based on a key field."""
@@ -21,45 +43,89 @@ def _merge_data(self, existing: list[dict] | None, new_data: list[dict], key_fie
2143
merged.extend([item for item in new_data if item[key_field] not in existing_keys])
2244
return merged
2345

24-
def get_prices(self, ticker: str) -> list[dict[str, any]] | None:
25-
"""Get cached price data if available."""
26-
return self._prices_cache.get(ticker)
27-
28-
def set_prices(self, ticker: str, data: list[dict[str, any]]):
29-
"""Append new price data to cache."""
30-
self._prices_cache[ticker] = self._merge_data(self._prices_cache.get(ticker), data, key_field="time")
31-
32-
def get_financial_metrics(self, ticker: str) -> list[dict[str, any]]:
33-
"""Get cached financial metrics if available."""
34-
return self._financial_metrics_cache.get(ticker)
35-
36-
def set_financial_metrics(self, ticker: str, data: list[dict[str, any]]):
37-
"""Append new financial metrics to cache."""
38-
self._financial_metrics_cache[ticker] = self._merge_data(self._financial_metrics_cache.get(ticker), data, key_field="report_period")
39-
40-
def get_line_items(self, ticker: str) -> list[dict[str, any]] | None:
41-
"""Get cached line items if available."""
42-
return self._line_items_cache.get(ticker)
43-
44-
def set_line_items(self, ticker: str, data: list[dict[str, any]]):
45-
"""Append new line items to cache."""
46-
self._line_items_cache[ticker] = self._merge_data(self._line_items_cache.get(ticker), data, key_field="report_period")
47-
48-
def get_insider_trades(self, ticker: str) -> list[dict[str, any]] | None:
49-
"""Get cached insider trades if available."""
50-
return self._insider_trades_cache.get(ticker)
51-
52-
def set_insider_trades(self, ticker: str, data: list[dict[str, any]]):
53-
"""Append new insider trades to cache."""
54-
self._insider_trades_cache[ticker] = self._merge_data(self._insider_trades_cache.get(ticker), data, key_field="filing_date") # Could also use transaction_date if preferred
55-
56-
def get_company_news(self, ticker: str) -> list[dict[str, any]] | None:
57-
"""Get cached company news if available."""
58-
return self._company_news_cache.get(ticker)
59-
60-
def set_company_news(self, ticker: str, data: list[dict[str, any]]):
61-
"""Append new company news to cache."""
62-
self._company_news_cache[ticker] = self._merge_data(self._company_news_cache.get(ticker), data, key_field="date")
46+
def _cleanup_expired(self, cache_dict: dict[str, CacheEntry]) -> None:
47+
"""Remove expired entries from a cache dictionary."""
48+
expired_keys = [key for key, entry in cache_dict.items() if entry.is_expired()]
49+
for key in expired_keys:
50+
del cache_dict[key]
51+
52+
def get_prices(self, ticker: str) -> list[dict[str, Any]] | None:
53+
"""Get cached price data if available and not expired."""
54+
self._cleanup_expired(self._prices_cache)
55+
entry = self._prices_cache.get(ticker)
56+
if entry and not entry.is_expired():
57+
return entry.data
58+
return None
59+
60+
def set_prices(self, ticker: str, data: list[dict[str, Any]]):
61+
"""Append new price data to cache with TTL."""
62+
existing_data = self.get_prices(ticker)
63+
merged = self._merge_data(existing_data, data, key_field="time")
64+
self._prices_cache[ticker] = CacheEntry(merged, ttl_seconds=self.DEFAULT_TTL)
65+
66+
def get_financial_metrics(self, ticker: str) -> list[dict[str, Any]] | None:
67+
"""Get cached financial metrics if available and not expired."""
68+
self._cleanup_expired(self._financial_metrics_cache)
69+
entry = self._financial_metrics_cache.get(ticker)
70+
if entry and not entry.is_expired():
71+
return entry.data
72+
return None
73+
74+
def set_financial_metrics(self, ticker: str, data: list[dict[str, Any]]):
75+
"""Append new financial metrics to cache with TTL."""
76+
existing_data = self.get_financial_metrics(ticker)
77+
merged = self._merge_data(existing_data, data, key_field="report_period")
78+
self._financial_metrics_cache[ticker] = CacheEntry(merged, ttl_seconds=self.METRICS_TTL)
79+
80+
def get_line_items(self, ticker: str) -> list[dict[str, Any]] | None:
81+
"""Get cached line items if available and not expired."""
82+
self._cleanup_expired(self._line_items_cache)
83+
entry = self._line_items_cache.get(ticker)
84+
if entry and not entry.is_expired():
85+
return entry.data
86+
return None
87+
88+
def set_line_items(self, ticker: str, data: list[dict[str, Any]]):
89+
"""Append new line items to cache with TTL."""
90+
existing_data = self.get_line_items(ticker)
91+
merged = self._merge_data(existing_data, data, key_field="report_period")
92+
self._line_items_cache[ticker] = CacheEntry(merged, ttl_seconds=self.METRICS_TTL)
93+
94+
def get_insider_trades(self, ticker: str) -> list[dict[str, Any]] | None:
95+
"""Get cached insider trades if available and not expired."""
96+
self._cleanup_expired(self._insider_trades_cache)
97+
entry = self._insider_trades_cache.get(ticker)
98+
if entry and not entry.is_expired():
99+
return entry.data
100+
return None
101+
102+
def set_insider_trades(self, ticker: str, data: list[dict[str, Any]]):
103+
"""Append new insider trades to cache with TTL."""
104+
existing_data = self.get_insider_trades(ticker)
105+
merged = self._merge_data(existing_data, data, key_field="filing_date")
106+
self._insider_trades_cache[ticker] = CacheEntry(merged, ttl_seconds=self.DEFAULT_TTL)
107+
108+
def get_company_news(self, ticker: str) -> list[dict[str, Any]] | None:
109+
"""Get cached company news if available and not expired."""
110+
self._cleanup_expired(self._company_news_cache)
111+
entry = self._company_news_cache.get(ticker)
112+
if entry and not entry.is_expired():
113+
return entry.data
114+
return None
115+
116+
def set_company_news(self, ticker: str, data: list[dict[str, Any]]):
117+
"""Append new company news to cache with TTL."""
118+
existing_data = self.get_company_news(ticker)
119+
merged = self._merge_data(existing_data, data, key_field="date")
120+
self._company_news_cache[ticker] = CacheEntry(merged, ttl_seconds=self.NEWS_TTL)
121+
122+
def clear(self) -> None:
123+
"""Clear all caches."""
124+
self._prices_cache.clear()
125+
self._financial_metrics_cache.clear()
126+
self._line_items_cache.clear()
127+
self._insider_trades_cache.clear()
128+
self._company_news_cache.clear()
63129

64130

65131
# Global cache instance

0 commit comments

Comments
 (0)