Files
Stock-trading-programming/scripts/build_external_training_dataset.py
T

184 lines
6.3 KiB
Python
Raw Normal View History

"""
Build training rows from external minute bars.
This generates synthetic candidate rows from minute bars, not actual bot trades.
Rows are useful for pretraining movement/holding-period models.
"""
import argparse
import csv
from collections import defaultdict
from pathlib import Path
# Two levels above this file: the directory that contains this script's folder.
ROOT = Path(__file__).resolve().parent.parent
# Default input folders for external minute/daily bars and the default output CSV.
MINUTE_ROOT = ROOT.joinpath("data", "external", "minute")
DAILY_ROOT = ROOT.joinpath("data", "external", "daily")
DEFAULT_OUT = ROOT.joinpath("data", "external_training_dataset.csv")
# Look-ahead horizons, counted in bars after the entry bar, for outcome columns.
CHECKPOINTS = (1, 3, 5, 10)
def _read_csv(path: Path) -> list[dict]:
    """Load a CSV file into a list of row dicts keyed by the header line.

    The file is decoded as ``utf-8-sig``, so a leading byte-order mark is
    stripped instead of ending up inside the first column name.
    """
    with path.open("r", encoding="utf-8-sig", newline="") as handle:
        reader = csv.DictReader(handle)
        return [record for record in reader]
def _num(value, default=0.0):
    """Parse *value* as a float, tolerating comma thousands separators.

    The value is stringified and every comma is removed before conversion.
    Anything that still cannot be parsed (``None``, blanks, arbitrary text)
    yields *default* instead of raising.
    """
    try:
        cleaned = str(value).replace(",", "")
        return float(cleaned)
    except (TypeError, ValueError):
        pass
    return default
def _load_daily_amounts() -> dict[str, list[dict]]:
    """Index every daily ``stocks.csv`` under ``DAILY_ROOT`` by ticker.

    Each ``<DAILY_ROOT>/<folder>/stocks.csv`` file is read. A row's date is
    taken from its ``date`` column, falling back to the name of the folder
    the file lives in; its ticker comes from the ``ticker`` column or, if
    that is empty, the ``티커`` column. Rows without a ticker are dropped.
    The resolved date is written back into ``row["date"]``.

    Returns:
        Mapping of ticker -> that ticker's full daily rows, one per date,
        ordered by ascending date string.
    """
    by_ticker = defaultdict(dict)
    # glob order is filesystem-dependent; sorting makes the "last file wins"
    # rule for a duplicated (ticker, date) pair reproducible across runs.
    for file in sorted(DAILY_ROOT.glob("*/stocks.csv")):
        for row in _read_csv(file):
            ticker = str(row.get("ticker") or row.get("티커") or "")
            if not ticker:
                continue
            date = str(row.get("date") or file.parent.name)
            row["date"] = date
            by_ticker[ticker][date] = row
    return {
        ticker: [rows[date] for date in sorted(rows)]
        for ticker, rows in by_ticker.items()
    }
def _previous_daily_row(daily: dict[str, list[dict]], date: str, ticker: str) -> dict:
    """Return the last daily row for *ticker* dated strictly before *date*.

    The ticker's rows are walked in their stored order and the walk stops at
    the first row whose ``date`` string is not lower than *date*, so the rows
    are expected to be in ascending date order. Dates are compared as
    strings. An empty dict is returned when the ticker is unknown or has no
    earlier row.
    """
    latest_before = {}
    for candidate in daily.get(ticker, []):
        if str(candidate.get("date", "")) < date:
            latest_before = candidate
        else:
            break
    return latest_before
def _future_metrics(rows: list[dict], idx: int, entry_price: float):
    """Compute forward-looking outcome columns for an entry at ``rows[idx]``.

    For every horizon ``m`` in ``CHECKPOINTS`` the bars
    ``rows[idx + 1 : idx + m + 1]`` are inspected and four columns are set:

    * ``price_{m}m`` - close of the bar ``m`` positions after the entry bar
    * ``ret_{m}m``   - percent change from *entry_price* to that close
    * ``mfe_{m}m``   - percent change from *entry_price* to the highest
      ``high`` in the window
    * ``mae_{m}m``   - percent change from *entry_price* to the lowest
      ``low`` in the window

    Horizons are counted in rows, not clock time (presumably one row per
    minute - confirm against the data source). If a horizon reaches past the
    end of *rows*, its four columns are set to ``""``. If *entry_price* is
    falsy, the percentage columns are 0.

    Args:
        rows: bars of a single day in chronological order.
        idx: index of the entry bar within *rows*.
        entry_price: price the percentages are measured against.

    Returns:
        Dict with the four columns above for each horizon.
    """
    metrics = {}

    def pct(price):
        # Avoid dividing by a zero/blank entry price.
        return (price - entry_price) / entry_price * 100 if entry_price else 0

    for minutes in CHECKPOINTS:
        j = idx + minutes
        if j >= len(rows):
            for name in ("price", "ret", "mfe", "mae"):
                metrics[f"{name}_{minutes}m"] = ""
            continue
        # Each window is evaluated on its own rather than accumulated across
        # horizons, so the result does not depend on CHECKPOINTS being sorted
        # and no bar is processed more than once per horizon.
        window = rows[idx + 1:j + 1]
        high = max((_num(r["high"]) for r in window), default=entry_price)
        low = min((_num(r["low"]) for r in window), default=entry_price)
        close = _num(rows[j]["close"])
        metrics[f"price_{minutes}m"] = close
        metrics[f"ret_{minutes}m"] = pct(close)
        metrics[f"mfe_{minutes}m"] = pct(high)
        metrics[f"mae_{minutes}m"] = pct(low)
    return metrics
def _rows_for_file(path: Path, daily: dict, k: float, breakout_only: bool):
    """Turn one minute-bar CSV into synthetic training rows.

    Bars without a ``time`` value or with a non-positive ``close`` are
    discarded, and a file with fewer than 20 usable bars yields nothing. The
    remaining bars are grouped by ``date`` and the days are processed in
    ascending date order. The ticker of a day is read from its first bar.

    For each day a breakout target ``open + (prev_high - prev_low) * k`` is
    derived from the latest daily row dated before that day. When that row
    lacks a high or a low, the high/low/amount aggregated from the previously
    processed day of this same file are used instead. With no usable previous
    range the target is 0.

    Candidate entries are bars whose first four ``time`` characters fall in
    ``["0905", "1400")`` (assumes an ``HHMM...`` time format - confirm), and
    the last ``max(CHECKPOINTS)`` bars of each day are never candidates so
    that every look-ahead horizon has data.

    Args:
        path: minute CSV with ``date``, ``time``, ``ticker``, ``open``,
            ``high``, ``low``, ``close`` and ``volume`` columns.
        daily: ticker -> date-ordered daily rows, as consumed by
            ``_previous_daily_row``.
        k: fraction of the previous range added to the day's open.
        breakout_only: if true, emit only the first candidate bar of each day
            whose close is at or above the target; otherwise emit every
            candidate bar.

    Returns:
        List of row dicts: entry context, the ``_future_metrics`` columns,
        ``label_win`` (10-bar return above 0) and ``label_stop_loss``
        (10-bar adverse excursion of -2% or worse).
    """
    bars = [
        r for r in _read_csv(path)
        if r.get("time") and _num(r.get("close")) > 0
    ]
    bars.sort(key=lambda r: (r.get("date", ""), r.get("time", "")))
    if len(bars) < 20:
        return []

    by_date = defaultdict(list)
    for bar in bars:
        by_date[bar["date"]].append(bar)

    horizon = max(CHECKPOINTS)
    out = []
    # ticker -> high/low/amount aggregated from the previously processed day.
    last_session = {}
    for date in sorted(by_date):
        day_rows = by_date[date]
        ticker = day_rows[0]["ticker"]
        daily_row = _previous_daily_row(daily, date, ticker)
        prev_high = _num(daily_row.get("high"))
        prev_low = _num(daily_row.get("low"))
        prev_amount = _num(daily_row.get("amount"))
        if not prev_high or not prev_low:
            fallback = last_session.get(ticker)
            if fallback:
                prev_high = fallback["high"]
                prev_low = fallback["low"]
                prev_amount = fallback["amount"]
        today_open = _num(day_rows[0]["open"]) or _num(day_rows[0]["close"])
        target = today_open + (prev_high - prev_low) * k if prev_high and prev_low else 0

        # Without a target no breakout can be detected, so skip the scan.
        if breakout_only and not target:
            candidates = []
        else:
            candidates = day_rows[:-horizon]
        for idx, row in enumerate(candidates):
            tm = row["time"][:4]
            if tm < "0905" or tm >= "1400":
                continue
            close = _num(row["close"])
            if breakout_only and close < target:
                continue
            metrics = _future_metrics(day_rows, idx, close)
            label_win = 1 if _num(metrics.get("ret_10m")) > 0 else 0
            label_stop_loss = 1 if _num(metrics.get("mae_10m")) <= -2.0 else 0
            out.append({
                "source": "external_minute",
                "date": date,
                "ticker": ticker,
                "entry_time": row["time"],
                "current_price": close,
                "entry_price": close,
                "target_price": target,
                "today_open": today_open,
                "prev_high": prev_high,
                "prev_low": prev_low,
                "prev_amount": prev_amount,
                "volume": row.get("volume", ""),
                **metrics,
                "label_win": label_win,
                "label_stop_loss": label_stop_loss,
            })
            if breakout_only:
                # Only the first breakout of the day is wanted; the remaining
                # bars cannot contribute anything.
                break

        # Blank/zero highs or lows are data glitches; letting one of them win
        # the min() would zero the range and disable the next day's target.
        highs = [v for v in (_num(r["high"]) for r in day_rows) if v > 0]
        lows = [v for v in (_num(r["low"]) for r in day_rows) if v > 0]
        last_session[ticker] = {
            "high": max(highs, default=0.0),
            "low": min(lows, default=0.0),
            "amount": sum(_num(r["close"]) * _num(r["volume"]) for r in day_rows),
        }
    return out
def main():
    """Command-line entry point: build the external training dataset CSV.

    Loads the daily rows, converts every ``<minute-root>/*/*.csv`` file with
    ``_rows_for_file`` and writes all resulting rows to one CSV whose columns
    are the alphabetically sorted union of the row keys. A relative ``--out``
    path is resolved against ``ROOT``; the output folder is created if needed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--minute-root", default=str(MINUTE_ROOT),
                        help="Folder containing <subfolder>/<name>.csv minute-bar files.")
    parser.add_argument("--out", default=str(DEFAULT_OUT),
                        help="Output CSV path; relative paths are resolved against the repository root.")
    parser.add_argument("--k", type=float, default=0.5,
                        help="Fraction of the previous range added to the open to form the breakout target.")
    parser.add_argument("--all-minutes", action="store_true", help="Use every eligible minute, not only first breakout.")
    args = parser.parse_args()

    daily = _load_daily_amounts()
    all_rows = []
    # glob order is filesystem-dependent; sort so the output row order is
    # reproducible from run to run and machine to machine.
    for path in sorted(Path(args.minute_root).glob("*/*.csv")):
        all_rows.extend(_rows_for_file(path, daily, args.k, breakout_only=not args.all_minutes))

    out_path = Path(args.out)
    if not out_path.is_absolute():
        out_path = ROOT / out_path
    out_path.parent.mkdir(parents=True, exist_ok=True)

    fieldnames = sorted({key for row in all_rows for key in row})
    with out_path.open("w", encoding="utf-8-sig", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(all_rows)
    print(f"external dataset rows={len(all_rows)} -> {out_path}")


# Run the builder only when executed as a script, not when imported.
if __name__ == "__main__":
    main()