基于强化学习的量化交易框架 TensorTrade

简介: TensorTrade 是一个基于强化学习的开源交易算法框架。它通过环境模拟、策略训练与奖励机制,让AI在历史数据中自主学习买卖时机,构建逻辑自洽的交易策略,助力量化研究。

打开交易图表,堆上十个技术指标,然后对着屏幕发呆不知道下一步怎么操作——这场景对交易员来说太熟悉了。如果把历史数据丢给计算机,告诉它“去试错”。赚了有奖励,亏了有惩罚。让它在不断的尝试和失败中学习,最终迭代出一个不说完美、但至少能逻辑自洽的交易策略。

这就是 TensorTrade 的核心逻辑。

TensorTrade 是一个专注于利用 强化学习 (Reinforcement Learning, RL) 构建和训练交易算法的开源 Python 框架。

数据获取与特征工程

这里用

yfinance

抓取数据,配合

pandas_ta

计算技术指标。对数收益率 (Log Returns)、RSI 和 MACD 是几个比较基础的特征输入。

  pip install yfinance pandas_ta

import yfinance as yf  
import pandas_ta as ta  
import pandas as pd  

# Pick your ticker  
TICKER = "TTRD"  # TODO: change this to something real, e.g. "AAPL", "BTC-USD"  
TRAIN_START_DATE = "2021-02-09"  
TRAIN_END_DATE   = "2021-09-30"  
EVAL_START_DATE  = "2021-10-01"  
EVAL_END_DATE    = "2021-11-12"  

def build_dataset(ticker, start, end, filename):  
    # 1. Download hourly OHLCV data  
    df = yf.Ticker(ticker).history(  
        start=start,  
        end=end,  
        interval="60m"  
    )  
    # 2. Clean up  
    df = df.drop(["Dividends", "Stock Splits"], axis=1)  
    df["Volume"] = df["Volume"].astype(int)  
    # 3. Add some basic features  
    df.ta.log_return(append=True, length=16)  
    df.ta.rsi(append=True, length=14)  
    df.ta.macd(append=True, fast=12, slow=26)  
    # 4. Move Datetime from index to column  
    df = df.reset_index()  
    # 5. Save  
    df.to_csv(filename, index=False)  
    print(f"Saved {filename} with {len(df)} rows")  

build_dataset(TICKER, TRAIN_START_DATE, TRAIN_END_DATE, "training.csv")  
 build_dataset(TICKER, EVAL_START_DATE,  EVAL_END_DATE,  "evaluation.csv")

脚本跑完,目录下会生成

training.csv

evaluation.csv

。包含了 OHLCV 基础数据和几个预处理好的指标。这些就是训练 RL 模型的数据。

构建 TensorTrade 交互环境

强化学习没法直接使用CSV 文件。所以需要一个标准的交互 环境 (Environment):能够输出当前状态 (State),接收智能体的动作 (Action),并反馈奖励 (Reward)。

TensorTrade 把这个过程模块化了:

  • Instrument:定义交易标的(如 USD, TTRD)。
  • Wallet:管理资产余额。
  • Portfolio:钱包组合。
  • Stream / DataFeed:处理特征数据流。
  • reward_scheme / action_scheme:定义怎么操作,以及操作的好坏怎么评分。
  pip install tensortrade

下面是一个环境工厂函数 (Environment Factory) 的实现,设计得比较轻量,这样可以方便后续接入 Ray:

 import os  
import pandas as pd  

from tensortrade.feed.core import DataFeed, Stream  
from tensortrade.oms.instruments import Instrument  
from tensortrade.oms.exchanges import Exchange, ExchangeOptions  
from tensortrade.oms.services.execution.simulated import execute_order  
from tensortrade.oms.wallets import Wallet, Portfolio  
import tensortrade.env.default as default  

def create_env(config):  
    """  
    Build a TensorTrade environment from a CSV.  
    config needs:  
      - csv_filename  
      - window_size  
      - reward_window_size  
      - max_allowed_loss  
    """  
    # 1. Read the dataset  
    dataset = (  
        pd.read_csv(config["csv_filename"], parse_dates=["Datetime"])  
        .fillna(method="backfill")  
        .fillna(method="ffill")  
    )  
    # 2. Price stream (we'll trade on Close)  
    commission = 0.0035  # 0.35%, tweak this to your broker  
    price = Stream.source(  
        list(dataset["Close"]), dtype="float"  
    ).rename("USD-TTRD")  
    options = ExchangeOptions(commission=commission)  
    exchange = Exchange("TTSE", service=execute_order, options=options)(price)  
    # 3. Instruments and wallets  
    USD = Instrument("USD", 2, "US Dollar")  
    TTRD = Instrument("TTRD", 2, "TensorTrade Corp")  # just a label  
    cash_wallet = Wallet(exchange, 1000 * USD)  # start with $1000  
    asset_wallet = Wallet(exchange, 0 * TTRD)   # start with zero TTRD  
    portfolio = Portfolio(USD, [cash_wallet, asset_wallet])  
    # 4. Renderer feed (optional, useful for plotting later)  
    renderer_feed = DataFeed([  
        Stream.source(list(dataset["Datetime"])).rename("date"),  
        Stream.source(list(dataset["Open"]), dtype="float").rename("open"),  
        Stream.source(list(dataset["High"]), dtype="float").rename("high"),  
        Stream.source(list(dataset["Low"]), dtype="float").rename("low"),  
        Stream.source(list(dataset["Close"]), dtype="float").rename("close"),  
        Stream.source(list(dataset["Volume"]), dtype="float").rename("volume"),  
    ])  
    renderer_feed.compile()  
    # 5. Feature feed for the RL agent  
    features = []  
    # Skip Datetime (first column) and stream everything else  
    for col in dataset.columns[1:]:  
        s = Stream.source(list(dataset[col]), dtype="float").rename(col)  
        features.append(s)  
    feed = DataFeed(features)  
    feed.compile()  
    # 6. Reward and action scheme  
    reward_scheme = default.rewards.SimpleProfit(  
        window_size=config["reward_window_size"]  
    )  
    action_scheme = default.actions.BSH(  
        cash=cash_wallet,  
        asset=asset_wallet  
    )  
    # 7. Put everything together in an environment  
    env = default.create(  
        portfolio=portfolio,  
        action_scheme=action_scheme,  
        reward_scheme=reward_scheme,  
        feed=feed,  
        renderer=[],  
        renderer_feed=renderer_feed,  
        window_size=config["window_size"],  
        max_allowed_loss=config["max_allowed_loss"]  
    )  
     return env

这样“游戏”规则就已经定好了:观察最近 N 根 K 线和指标(State),决定买卖持(Action),目标是让一段时间内的利润最大化(Reward)。

基于 Ray RLlib 与 PPO 算法的模型训练

底层环境搭好,接下来让 Ray RLlib 介入处理 RL 的核心逻辑。

选用 PPO (Proximal Policy Optimization) 算法,这在连续控制和离散动作空间都有不错的表现。为了找到更优解,顺手做一个简单的超参数网格搜索:网络架构、学习率、Minibatch 大小,都跑一遍试试。

  pip install "ray[rllib]"

训练脚本如下:

 import os  
import ray  
from ray import tune  
from ray.tune.registry import register_env  

from your_module import create_env  # wherever you defined create_env  

# Some hyperparameter grids to try  
FC_SIZE = tune.grid_search([  
    [256, 256],  
    [1024],  
    [128, 64, 32],  
])  
LEARNING_RATE = tune.grid_search([  
    0.001,  
    0.0005,  
    0.00001,  
])  
MINIBATCH_SIZE = tune.grid_search([  
    5,  
    10,  
    20,  
])  
cwd = os.getcwd()  
# Register our custom environment with RLlib  
register_env("MyTrainingEnv", lambda cfg: create_env(cfg))  
env_config_training = {  
    "window_size": 14,  
    "reward_window_size": 7,  
    "max_allowed_loss": 0.10,  # cut episodes early if loss > 10%  
    "csv_filename": os.path.join(cwd, "training.csv"),  
}  
env_config_evaluation = {  
    "max_allowed_loss": 1.00,  
    "csv_filename": os.path.join(cwd, "evaluation.csv"),  
}  
ray.init(ignore_reinit_error=True)  
analysis = tune.run(  
    run_or_experiment="PPO",  
    name="MyExperiment1",  
    metric="episode_reward_mean",  
    mode="max",  
    stop={  
        "training_iteration": 5,  # small for demo, increase in real runs  
    },  
    config={  
        "env": "MyTrainingEnv",  
        "env_config": env_config_training,  
        "log_level": "WARNING",  
        "framework": "torch",     # or "tf"  
        "ignore_worker_failures": True,  
        "num_workers": 1,  
        "num_envs_per_worker": 1,  
        "num_gpus": 0,  
        "clip_rewards": True,  
        "lr": LEARNING_RATE,  
        "gamma": 0.50,            # discount factor  
        "observation_filter": "MeanStdFilter",  
        "model": {  
            "fcnet_hiddens": FC_SIZE,  
        },  
        "sgd_minibatch_size": MINIBATCH_SIZE,  
        "evaluation_interval": 1,  
        "evaluation_config": {  
            "env_config": env_config_evaluation,  
            "explore": False,     # no exploration during evaluation  
        },  
    },  
    num_samples=1,  
    keep_checkpoints_num=10,  
    checkpoint_freq=1,  
 )

这段代码本质上是在运行一场“交易机器人锦标赛”。Ray 会根据定义的参数组合并行训练多个 PPO 智能体,追踪它们的平均回合奖励,并保存下表现最好的 Checkpoint 供后续调用。

自定义奖励机制 (PBR)

默认的

SimpleProfit

奖励逻辑很简单,但实战中往往过于粗糙。我们有时需要根据具体的交易逻辑来重塑奖励函数。比如说基于持仓的奖励方案 PBR (Position-Based Reward)

  • 维护当前持仓状态(多头或空头)。
  • 监控价格变动。
  • 奖励计算 = 价格变动 × 持仓方向。

价格涨了你做多,给正反馈;价格跌了你做空,也给正反馈。反之则是惩罚。

 from tensortrade.env.default.rewards import RewardScheme  
from tensortrade.feed.core import DataFeed, Stream  

class PBR(RewardScheme):  
    """  
    Position-Based Reward (PBR)  
    Rewards the agent based on price changes and its current position.  
    """  
    registered_name = "pbr"  
    def __init__(self, price: Stream):  
        super().__init__()  
        self.position = -1  # start flat/short  
        # Price differences  
        r = Stream.sensor(price, lambda p: p.value, dtype="float").diff()  
        # Position stream  
        position = Stream.sensor(self, lambda rs: rs.position, dtype="float")  
        # Reward = price_change * position  
        reward = (r * position).fillna(0).rename("reward")  
        self.feed = DataFeed([reward])  
        self.feed.compile()  
    def on_action(self, action: int):  
        # Simple mapping: action 0 = long, everything else = short  
        self.position = 1 if action == 0 else -1  
    def get_reward(self, portfolio):  
        return self.feed.next()["reward"]  
    def reset(self):  
        self.position = -1  
         self.feed.reset()

接入也很简单,在

create_env

函数里替换掉原来的

reward_scheme

即可:

 reward_scheme = PBR(price)

这样改的好处是反馈更密集。智能体不需要等到最后平仓才知道赚没赚,每一个 step 都能收到关于“是否站对了队”的信号。

后续优化方向与建议

这套流程跑通只是个开始,想要真正可用,还有很多工作要做 比如:

  • 数据置换:代码里的 TTRD 只是个占位符,换成真实的标的(股票、Crypto、指数)。
  • 特征工程:RSI 和 MACD 只是抛砖引玉,试试 ATR、布林带,或者引入更长时间周期的特征。
  • 参数调优gamma(折扣因子)、window_size(观测窗口)对策略风格影响巨大,值得花时间去扫参。
  • 基准测试:这一步最关键。把你训练出来的 RL 策略和 Buy & Hold(买入持有)比一比,甚至和随机策略比一比。如果跑不过随机策略,那就得从头检查了。

最后别忘了,我们只是研究,所以不要直接实盘。模型在训练集上大杀四方是常态,能通过样本外测试和模拟盘 (Paper Trading) 的考验才是真本事。

https://avoid.overfit.cn/post/8c9e08414e514c73ab3aefd694294f79
作者:CodeBun

目录
相关文章
|
4月前
|
弹性计算 人工智能 对象存储
阿里云服务器最新优惠价格表:含 ECS、轻量、GPU 配置及收费标准
阿里云服务器多少钱?阿里云服务器优惠价格表:涵盖轻量应用服务器、ECS 云服务器、GPU 服务器等主流产品,低至 38 元1年、99元和199元收费,部分配置升级至 200M 带宽且不限流量,无论是个人开发者、中小企业还是大型企业,都能找到适配需求的高性价比方案。以下是整理的阿里云最新优惠价格及配置详情::轻量应用服务器200M峰值带宽68元1年(秒杀38元),ECS云服务器2核2G3M带宽99元一年、2核4G、5M带宽、80G系统盘优惠价格199元一年,4核16G服务器10M带宽89元1个月,8核32G服务器10M固定带宽160元一个月,阿里云香港轻量服务器200M带宽25元个月起。方便大
|
4月前
|
机器学习/深度学习 人工智能 运维
构建AI智能体:六十二、金融风控系统:基于信息熵和KL散度的异常交易检测
本文介绍了一种基于信息论的智能金融风控系统,通过KL散度、信息增益和熵等核心概念构建欺诈检测框架。系统首先生成模拟金融交易数据,区分正常与欺诈交易;然后计算各特征的数据熵和KL散度,量化分布差异;再训练随机森林模型进行预测,并创新性地结合概率和不确定性计算风险得分。实验表明,设备风险是最强欺诈指标,系统AUC达1.0,能有效识别典型欺诈模式(大额、深夜、高频交易)。该方法将抽象信息论转化为实用解决方案,在保持高性能的同时增强了模型可解释性,为智能风控提供了量化分析框架。
426 3
|
4月前
|
Kubernetes Cloud Native Nacos
MCP 网关实战:基于 Higress + Nacos 的零代码工具扩展方案
本文会围绕如何基于 Higress 和 Nacos 的 docker 镜像在 K8s 集群上进行分角色部署。
781 75
|
4月前
|
人工智能 自然语言处理 Java
AI工具选择困难症?Spring AI帮你省掉64%的令牌费用
你的AI助手有50+个工具但每次对话前就烧掉55000个令牌?就像带着全套工具箱去拧个螺丝一样浪费!Spring AI的工具搜索模式让AI按需发现工具,实现34-64%的令牌节省,告别工具选择困难症和账单焦虑。#Spring AI #工具优化 #令牌节省 #AI开发
618 2
|
3月前
|
机器学习/深度学习 Java
为什么所有主流LLM都使用SwiGLU?
本文解析现代大语言模型为何用SwiGLU替代ReLU。SwiGLU结合Swish与门控机制,通过乘法交互实现特征组合,增强表达能力;其平滑性与非饱和梯度利于优化,相较ReLU更具优势。
254 8
为什么所有主流LLM都使用SwiGLU?
|
4月前
|
人工智能 运维 监控
开源项目分享:Gitee热榜项目 2025年12月第二周 周榜
本文档汇总Gitee本周热门开源项目,涵盖Fay、JeeLowCode等明星项目,结合AI与低代码趋势,深入分析技术融合与场景创新,助力开发者把握前沿动态。
497 2
|
3月前
|
机器学习/深度学习 测试技术 数据中心
九坤量化开源IQuest-Coder-V1,代码大模型进入“流式”训练时代
2026年首日,九坤创始团队成立的至知创新研究院开源IQuest-Coder-V1系列代码大模型,涵盖7B至40B参数,支持128K上下文与GQA架构,提供Base、Instruct、Thinking及Loop版本。采用创新Code-Flow训练范式,模拟代码演化全过程,提升复杂任务推理能力,在SWE-Bench、LiveCodeBench等基准领先。全阶段checkpoint开放,支持本地部署与微调,助力研究与应用落地。
1271 2
|
机器学习/深度学习 算法 PyTorch
昇腾910-PyTorch 实现 ResNet50图像分类
本实验基于PyTorch,在昇腾平台上使用ResNet50对CIFAR10数据集进行图像分类训练。内容涵盖ResNet50的网络架构、残差模块分析及训练代码详解。通过端到端的实战讲解,帮助读者理解如何在深度学习中应用ResNet50模型,并实现高效的图像分类任务。实验包括数据预处理、模型搭建、训练与测试等环节,旨在提升模型的准确率和训练效率。
791 54
|
4月前
|
机器学习/深度学习 人工智能 弹性计算
阿里云服务器租用价格:最新包年包月、按量付费活动价格参考
阿里云服务器租用价格又更新了,租用阿里云轻量应用服务器一年价格是38元,经济型e实例2核2G3M带宽 40G ESSD Entry云盘特惠价99元1年,通用算力型u1实例2核4G5M带宽80G ESSD Entry云盘特惠价199元1年。通用算力型u2i实例4核8G1170.26元1年起。本文为大家展示本次价格更新之后,阿里云服务器的最新租用价格,包含经济型e、通用算力型u2i/u2a、计算型c9i/c9a、通用型g9i/g9a、内存型r9i/r9a等不同实例规格的活动价格,以供大家对比和选择参考。
858 13
|
4月前
|
弹性计算
2 核 4G云服务器多少钱?2 核 4G阿里云 ECS 云服务器计算型 c9i 实例测评
阿里云 ECS 计算型 c9i 实例(2 核 4G)凭借高性能配置,成为众多用户的热门选择。本文整理了该实例的核心参数、多地域租用价格、计费方式及折扣政策,帮助你快速了解选购要点。