Files
Stock-trading-programming/scripts/build_external_training_dataset.py
T

171 lines
5.9 KiB
Python
Raw Normal View History

"""
Build training rows from external minute bars.
This generates synthetic candidate rows from minute bars, not actual bot trades.
Rows are useful for pretraining movement/holding-period models.
"""
import argparse
import csv
from collections import defaultdict
from pathlib import Path
# Project root: the directory two levels above this file
# (i.e. the parent of the folder that contains this script).
ROOT = Path(__file__).resolve().parent.parent
# Default folder scanned for minute-bar CSVs (overridable with --minute-root).
MINUTE_ROOT = ROOT / "data" / "external" / "minute"
# Daily data is read from <DAILY_ROOT>/<folder>/stocks.csv.
DAILY_ROOT = ROOT / "data" / "external" / "daily"
# Default output CSV (overridable with --out).
DEFAULT_OUT = ROOT / "data" / "external_training_dataset.csv"
# Forward offsets, counted in bars after the entry bar, at which outcome
# metrics are computed; each value also names the output columns ("ret_10m").
CHECKPOINTS = (1, 3, 5, 10)
def _read_csv(path: Path) -> list[dict]:
with path.open("r", encoding="utf-8-sig", newline="") as f:
return list(csv.DictReader(f))
def _num(value, default=0.0):
try:
return float(str(value).replace(",", ""))
except (TypeError, ValueError):
return default
def _load_daily_amounts() -> dict[tuple[str, str], dict]:
    """Index every row of every ``<DAILY_ROOT>/*/stocks.csv`` by (date, ticker).

    The date comes from the row's ``date`` column, falling back to the name
    of the folder that holds the file.  The ticker comes from the ``ticker``
    column or, failing that, the ``티커`` column; rows with neither are
    skipped.  When the same key occurs more than once, the row read last
    replaces the earlier one.
    """
    index = {}
    for csv_path in DAILY_ROOT.glob("*/stocks.csv"):
        folder_date = csv_path.parent.name
        for record in _read_csv(csv_path):
            symbol = str(record.get("ticker") or record.get("티커") or "")
            if not symbol:
                continue
            day = str(record.get("date") or folder_date)
            index[(day, symbol)] = record
    return index
def _future_metrics(rows: list[dict], idx: int, entry_price: float):
    """Measure what price did after the bar at ``idx``.

    For every offset ``n`` in ``CHECKPOINTS`` the bar ``n`` rows after
    ``idx`` is inspected and four keys are produced:

    * ``price_{n}m`` - close of that bar
    * ``ret_{n}m``   - percent change of that close versus ``entry_price``
    * ``mfe_{n}m``   - percent distance from ``entry_price`` to the highest
      ``high`` among the bars after ``idx`` up to and including that bar
    * ``mae_{n}m``   - same, using the lowest ``low``

    Offsets are row counts, not clock differences.
    NOTE(review): if the input can have missing minutes, "10m" is really
    "10 bars later" - confirm the feed is gap-free or accept the drift.

    Args:
        rows: Bars for one session, already in time order.
        idx: Position of the entry bar in ``rows``.
        entry_price: Reference price for the percentages.

    Returns:
        Dict with the four keys per checkpoint.  All four are ``""`` when
        the checkpoint lies beyond the end of ``rows``; the three percent
        values are ``0`` when ``entry_price`` is falsy.
    """

    def pct_from_entry(price):
        # Guard against a zero/empty entry price rather than dividing by it.
        return (price - entry_price) / entry_price * 100 if entry_price else 0

    metrics = {}
    for minutes in CHECKPOINTS:
        j = idx + minutes
        if j >= len(rows):
            for prefix in ("price", "ret", "mfe", "mae"):
                metrics[f"{prefix}_{minutes}m"] = ""
            continue
        # Each window already spans entry+1 .. checkpoint, so the extremes
        # are taken from it directly.  This keeps every checkpoint correct
        # regardless of the order of CHECKPOINTS and avoids re-accumulating
        # the overlapping bars of earlier checkpoints.
        window = rows[idx + 1:j + 1]
        high = max((_num(r["high"]) for r in window), default=entry_price)
        low = min((_num(r["low"]) for r in window), default=entry_price)
        close = _num(rows[j]["close"])
        metrics[f"price_{minutes}m"] = close
        metrics[f"ret_{minutes}m"] = pct_from_entry(close)
        metrics[f"mfe_{minutes}m"] = pct_from_entry(high)
        metrics[f"mae_{minutes}m"] = pct_from_entry(low)
    return metrics
def _hhmm(raw_time) -> str:
    """Return the leading ``HHMM`` digits of a bar time, ignoring separators."""
    return "".join(ch for ch in str(raw_time) if ch.isdigit())[:4]


def _rows_for_file(path: Path, daily: dict, k: float, breakout_only: bool):
    """Turn one minute-bar CSV into synthetic training rows.

    Bars with no ``time`` or a non-positive ``close`` are dropped; a file
    left with fewer than 20 bars yields nothing.  The remaining bars are
    grouped by ``date``.  For each date a target price
    ``open + (prev_high - prev_low) * k`` is derived, where the reference
    high/low/amount come from ``daily[(date, ticker)]`` when that row has
    both a high and a low, and otherwise from the extremes of the preceding
    date found in this same file.  The ticker is read from the first bar of
    each date.

    NOTE(review): the daily row is looked up under the *same* date as the
    minute bars yet is used as the previous session's range - confirm the
    daily files are keyed so that this is not look-ahead.

    Only bars whose time is in [09:05, 14:00) are considered as entries, and
    the last ``max(CHECKPOINTS)`` bars of each date are never entries.

    Args:
        path: Minute-bar CSV to read.
        daily: Mapping of ``(date, ticker)`` to a daily row.
        k: Multiplier applied to the reference range when building the
            target price.
        breakout_only: When true, emit at most one row per date - the first
            eligible bar whose close is at or above the target (none if no
            target could be computed).  When false, emit every eligible bar.

    Returns:
        List of dict rows holding the entry context, the forward metrics
        from ``_future_metrics`` and the ``label_win`` /
        ``label_stop_loss`` labels.
    """
    rows = [
        r for r in _read_csv(path)
        if r.get("time") and _num(r.get("close")) > 0
    ]
    rows.sort(key=lambda r: (r.get("date", ""), r.get("time", "")))
    if len(rows) < 20:
        return []

    by_date = defaultdict(list)
    for row in rows:
        by_date[row["date"]].append(row)

    # Entries need a full look-ahead window, so the tail of each day is cut.
    horizon = max(CHECKPOINTS)
    out = []
    # Last seen session extremes, keyed by ticker, used when no daily row
    # supplies a usable high/low.
    prev_by_ticker = {}
    for date in sorted(by_date):
        day_rows = by_date[date]
        ticker = day_rows[0]["ticker"]
        daily_row = daily.get((date, ticker), {})
        prev_high = _num(daily_row.get("high"))
        prev_low = _num(daily_row.get("low"))
        prev_amount = _num(daily_row.get("amount"))
        if not prev_high or not prev_low:
            prev = prev_by_ticker.get(ticker)
            if prev:
                prev_high = prev["high"]
                prev_low = prev["low"]
                prev_amount = prev["amount"]
        today_open = _num(day_rows[0]["open"]) or _num(day_rows[0]["close"])
        if prev_high and prev_low:
            target = today_open + (prev_high - prev_low) * k
        else:
            # No reference range: a zero target disables breakout detection.
            target = 0
        crossed = False
        for idx, row in enumerate(day_rows[:-horizon]):
            # Compare on digits only so "09:05:00" and "090500" are treated
            # alike; slicing the raw string mis-sorts separator-bearing
            # times against the "0905"/"1400" bounds.
            tm = _hhmm(row["time"])
            if tm < "0905" or tm >= "1400":
                continue
            close = _num(row["close"])
            if breakout_only:
                if not target or close < target or crossed:
                    continue
                crossed = True
            metrics = _future_metrics(day_rows, idx, close)
            label_win = 1 if _num(metrics.get("ret_10m")) > 0 else 0
            label_stop_loss = 1 if _num(metrics.get("mae_10m")) <= -2.0 else 0
            out.append({
                "source": "external_minute",
                "date": date,
                "ticker": ticker,
                "entry_time": row["time"],
                "current_price": close,
                "entry_price": close,
                "target_price": target,
                "today_open": today_open,
                "prev_high": prev_high,
                "prev_low": prev_low,
                "prev_amount": prev_amount,
                "volume": row.get("volume", ""),
                **metrics,
                "label_win": label_win,
                "label_stop_loss": label_stop_loss,
            })
        # Remember this session's range for the next date of the same ticker.
        prev_by_ticker[ticker] = {
            "high": max(_num(r["high"]) for r in day_rows),
            "low": min(_num(r["low"]) for r in day_rows),
            "amount": sum(_num(r["close"]) * _num(r["volume"]) for r in day_rows),
        }
    return out
def main():
    """Command-line entry point: build the external training CSV.

    Options:
        --minute-root  Folder scanned for ``*/*.csv`` minute-bar files.
        --out          Output CSV; a relative path is resolved against ROOT.
        --k            Range multiplier used for the breakout target.
        --all-minutes  Emit every eligible minute instead of only the first
                       breakout per day.

    The output columns are the sorted union of the keys of all generated
    rows, and a one-line summary is printed when the file has been written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--minute-root", default=str(MINUTE_ROOT))
    parser.add_argument("--out", default=str(DEFAULT_OUT))
    parser.add_argument("--k", type=float, default=0.5)
    parser.add_argument("--all-minutes", action="store_true", help="Use every eligible minute, not only first breakout.")
    args = parser.parse_args()

    daily = _load_daily_amounts()
    all_rows = []
    # Glob order follows the filesystem and is not guaranteed; sorting makes
    # the row order of the dataset reproducible across runs and machines.
    for path in sorted(Path(args.minute_root).glob("*/*.csv")):
        all_rows.extend(_rows_for_file(path, daily, args.k, breakout_only=not args.all_minutes))

    out_path = Path(args.out)
    if not out_path.is_absolute():
        out_path = ROOT / out_path
    out_path.parent.mkdir(parents=True, exist_ok=True)

    fieldnames = sorted({key for row in all_rows for key in row.keys()})
    with out_path.open("w", encoding="utf-8-sig", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(all_rows)
    print(f"external dataset rows={len(all_rows)} -> {out_path}")
# Build the dataset only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()