backtrader.reports.charts 源代码

#!/usr/bin/env python
"""
Report-specific chart generator.

Generates static charts for reports, distinct from interactive plotting.
"""

import base64
import io

try:
    import matplotlib

    matplotlib.use("Agg")  # Use non-interactive backend
    import matplotlib.dates as mdates
    import matplotlib.pyplot as plt

    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False

try:
    import pandas as pd

    PANDAS_AVAILABLE = True
except ImportError:
    PANDAS_AVAILABLE = False


[文档] class ReportChart: """Report-specific chart generator. Generates static charts for reports, including: - Equity curve chart (with buy-and-hold comparison line) - Return bars chart (automatic period detection) - Drawdown area chart Attributes: figsize: Default chart size dpi: Chart resolution """
[文档] def __init__(self, figsize=(10, 3), dpi=100): """Initialize the chart generator. Args: figsize: Chart size (width, height) dpi: Chart resolution """ self.figsize = figsize self.dpi = dpi self._figures = []
[文档] def plot_equity_curve( self, dates, values, benchmark_dates=None, benchmark_values=None, title="Equity Curve" ): """Plot equity curve chart. Args: dates: List of dates values: List of equity values benchmark_dates: List of benchmark dates (optional) benchmark_values: List of benchmark values (optional, e.g., buy-and-hold) title: Chart title Returns: matplotlib.figure.Figure or None """ if not MATPLOTLIB_AVAILABLE or not dates or not values: return None fig, ax = plt.subplots(1, 1, figsize=self.figsize, dpi=self.dpi) # Normalize to 100 start_value = values[0] if values[0] != 0 else 1 normalized_values = [100 * v / start_value for v in values] # Plot equity curve ax.plot(dates, normalized_values, label="Strategy", linewidth=1.5, color="#3498DB") # Plot buy-and-hold comparison line if benchmark_dates and benchmark_values: ax.plot( benchmark_dates, benchmark_values, label="Buy & Hold", linewidth=1, color="gray", linestyle="--", ) # Plot baseline (100) ax.axhline(y=100, color="gray", linestyle=":", linewidth=0.8, alpha=0.7) ax.set_ylabel("Net Asset Value (start=100)") ax.set_title(title) ax.legend(loc="upper left") ax.grid(True, alpha=0.3) # Format x-axis dates if len(dates) > 0: ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d")) plt.xticks(rotation=45) plt.tight_layout() self._figures.append(fig) return fig
[文档] def plot_return_bars(self, dates, values, period="auto", title=None): """Plot return bars chart. Args: dates: List of dates values: List of equity values period: Period ('auto', 'daily', 'weekly', 'monthly', 'yearly') title: Chart title Returns: matplotlib.figure.Figure or None """ if not MATPLOTLIB_AVAILABLE or not PANDAS_AVAILABLE: return None if not dates or not values: return None # Create Series series = pd.Series(data=values, index=pd.to_datetime(dates)) # Auto-detect period if period == "auto": period_name, period_code = self._get_periodicity(dates) else: period_map = { "daily": ("Daily", "D"), "weekly": ("Weekly", "W"), "monthly": ("Monthly", "ME"), "yearly": ("Yearly", "YE"), } period_name, period_code = period_map.get(period, ("Daily", "D")) # Resample and calculate returns try: resampled = series.resample(period_code).last() returns = 100 * resampled.pct_change().dropna() except Exception: return None if len(returns) == 0: return None fig, ax = plt.subplots(1, 1, figsize=self.figsize, dpi=self.dpi) # Set colors based on positive/negative values colors = ["green" if r > 0 else "red" for r in returns.values] # Plot bar chart x_labels = [ d.strftime("%Y-%m-%d") if hasattr(d, "strftime") else str(d) for d in returns.index ] ax.bar(range(len(returns)), returns.values, color=colors, alpha=0.7) # Set x-axis labels if len(x_labels) <= 20: ax.set_xticks(range(len(returns))) ax.set_xticklabels(x_labels, rotation=45, ha="right") else: # Show only partial labels when too many step = len(x_labels) // 10 ax.set_xticks(range(0, len(returns), step)) ax.set_xticklabels(x_labels[::step], rotation=45, ha="right") ax.axhline(y=0, color="gray", linestyle="-", linewidth=0.5) ax.set_ylabel("Return (%)") ax.set_title(title or f"{period_name} Returns") ax.grid(True, alpha=0.3, axis="y") plt.tight_layout() self._figures.append(fig) return fig
[文档] def plot_drawdown(self, dates, values, title="Drawdown"): """Plot drawdown area chart. Args: dates: List of dates values: List of equity values title: Chart title Returns: matplotlib.figure.Figure or None """ if not MATPLOTLIB_AVAILABLE or not dates or not values: return None # Calculate drawdown running_max = values[0] drawdowns = [] for v in values: if v > running_max: running_max = v dd = (v - running_max) / running_max * 100 if running_max != 0 else 0 drawdowns.append(dd) fig, ax = plt.subplots(1, 1, figsize=self.figsize, dpi=self.dpi) # Plot drawdown area ax.fill_between(dates, drawdowns, 0, alpha=0.3, color="red", label="Drawdown") ax.plot(dates, drawdowns, color="red", linewidth=1) ax.set_ylabel("Drawdown (%)") ax.set_title(title) ax.grid(True, alpha=0.3) # Show maximum drawdown max_dd = min(drawdowns) max_dd_idx = drawdowns.index(max_dd) ax.annotate( f"Max: {max_dd:.2f}%", xy=(dates[max_dd_idx], max_dd), xytext=(10, 10), textcoords="offset points", fontsize=9, color="red", ) # Format x-axis dates if len(dates) > 0: ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d")) plt.xticks(rotation=45) plt.tight_layout() self._figures.append(fig) return fig
def _get_periodicity(self, dates): """Intelligently determine the best display period. Args: dates: List of dates Returns: tuple: (period name, period code) """ if not dates or len(dates) < 2: return ("Daily", "D") try: start_date = dates[0] end_date = dates[-1] if hasattr(start_date, "days"): time_interval_days = (end_date - start_date).days else: from datetime import datetime if isinstance(start_date, datetime): time_interval_days = (end_date - start_date).days else: time_interval_days = 30 # Default if time_interval_days > 5 * 365.25: return ("Yearly", "YE") elif time_interval_days > 365.25: return ("Monthly", "ME") elif time_interval_days > 50: return ("Weekly", "W") elif time_interval_days > 5: return ("Daily", "D") elif time_interval_days > 0.5: return ("Hourly", "H") else: return ("Per Minute", "T") except Exception: return ("Daily", "D")
[文档] def save_to_file(self, fig, filename, format="png"): """Save chart to file. Args: fig: matplotlib figure object filename: Output filename format: Image format ('png', 'jpg', 'svg', 'pdf') """ if fig is None: return fig.savefig(filename, format=format, dpi=self.dpi, bbox_inches="tight")
[文档] def to_base64(self, fig, format="png"): """Convert chart to base64 encoding. Args: fig: matplotlib figure object format: Image format Returns: str: base64 encoded image data """ if fig is None: return "" buf = io.BytesIO() fig.savefig(buf, format=format, dpi=self.dpi, bbox_inches="tight") buf.seek(0) img_base64 = base64.b64encode(buf.read()).decode("utf-8") buf.close() return f"data:image/{format};base64,{img_base64}"
[文档] def close_all(self): """Close all charts and release memory.""" for fig in self._figures: plt.close(fig) self._figures = []