Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions rqalpha/data/data_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# 详细的授权流程,请联系 public@ricequant.com 获取。

from datetime import datetime, date
from typing import Union, List, Sequence, Optional, Tuple, Iterable, Dict, Callable
from typing import Union, List, Sequence, Optional, Tuple

import numpy as np
import pandas as pd
Expand All @@ -36,9 +36,11 @@
from rqalpha.core.execution_context import ExecutionContext
from rqalpha.utils.typing import DateLike
from rqalpha.utils.exception import InstrumentNotFound
from rqalpha.data.base_data_source.storages import FuturesTradingParameters

from .instruments_mixin import InstrumentsMixin


class DataProxy(TradingDatesMixin, InstrumentsMixin):
def __init__(self, data_source: AbstractDataSource, price_board: AbstractPriceBoard):
self._data_source = data_source
Expand Down Expand Up @@ -152,18 +154,18 @@ def _get_settlement(self, instrument, dt):

def get_prev_settlement(self, order_book_id, dt):
instrument = self.instruments(order_book_id)
if instrument.type not in (INSTRUMENT_TYPE.FUTURE, INSTRUMENT_TYPE.OPTION):
if instrument.type not in (INSTRUMENT_TYPE.FUTURE, INSTRUMENT_TYPE.OPTION, INSTRUMENT_TYPE.SPOT):
return np.nan
return self._get_prev_settlement(instrument, dt)

def get_settlement(self, instrument: Instrument, dt: datetime) -> float:
if instrument.type != INSTRUMENT_TYPE.FUTURE:
if instrument.type not in (INSTRUMENT_TYPE.FUTURE, INSTRUMENT_TYPE.OPTION, INSTRUMENT_TYPE.SPOT):
raise LookupError("'{}', instrument_type={}".format(instrument.order_book_id, instrument.type))
return self._get_settlement(instrument, dt)

def get_settle_price(self, order_book_id, trading_dt: datetime):
instrument = self.get_active_instrument(order_book_id, trading_dt)
if instrument.type != 'Future':
if instrument.type not in (INSTRUMENT_TYPE.FUTURE, INSTRUMENT_TYPE.OPTION, INSTRUMENT_TYPE.SPOT):
return np.nan
return self._data_source.get_settle_price(instrument, trading_dt)

Expand Down Expand Up @@ -260,8 +262,7 @@ def tick_fields_for(ins):
def available_data_range(self, frequency):
return self._data_source.available_data_range(frequency)

def get_futures_trading_parameters(self, order_book_id, dt):
# type: (str, datetime.date) -> FuturesTradingParameters
def get_futures_trading_parameters(self, order_book_id: str, dt: datetime.date) -> FuturesTradingParameters:
instrument = self.instruments(order_book_id)
return self._data_source.get_futures_trading_parameters(instrument, dt)

Expand All @@ -288,8 +289,7 @@ def get_tick_size(self, order_book_id):
def get_last_price(self, order_book_id: str) -> float:
return float(self._price_board.get_last_price(order_book_id))

def get_future_contracts(self, underlying, date):
# type: (str, DateLike) -> List[str]
def get_future_contracts(self, underlying: str, date: DateLike) -> List[str]:
return sorted(i.order_book_id for i in self.all_instruments(
[INSTRUMENT_TYPE.FUTURE], date
) if i.underlying_symbol == underlying and not Instrument.is_future_continuous_contract(i.order_book_id))
Expand Down
66 changes: 13 additions & 53 deletions rqalpha/mod/rqalpha_mod_sys_accounts/api/api_stock.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from rqalpha.utils.exception import RQInvalidArgument
from rqalpha.utils.i18n import gettext as _
from rqalpha.utils.logger import user_system_log
from rqalpha.mod.rqalpha_mod_sys_accounts.trade_utils import estimate_transaction_cost_calculator, round_order_quantity, get_amount_from_value
from .order_target_portfolio import order_target_portfolio_smart

# 使用Decimal 解决浮点数运算精度问题
Expand All @@ -59,10 +60,6 @@
export_as_api(sector_code, name='sector_code')


KSH_MIN_AMOUNT = 200
BJSE_MIN_AMOUNT = 100


def _get_account_position(order_book_id: str):
try:
account = Environment.get_instance().portfolio.accounts[DEFAULT_ACCOUNT_TYPE.STOCK]
Expand All @@ -74,21 +71,6 @@ def _get_account_position(order_book_id: str):
return account, position


def _round_order_quantity(ins, quantity, method: Callable = int) -> int:
if ins.type == "CS" and ins.board_type == "KSH":
# KSH can buy(sell) 201, 202 shares
return 0 if abs(quantity) < KSH_MIN_AMOUNT else int(quantity)
elif ins.type == "CS" and ins.board_type == "BJS":
# BJSE can buy(sell) 101, 202 shares
return 0 if abs(quantity) < BJSE_MIN_AMOUNT else int(quantity)
else:
round_lot = ins.round_lot
try:
return method(Decimal(quantity) / Decimal(round_lot)) * round_lot
except ValueError:
raise


def _get_order_style_price(order_book_id, style):
if isinstance(style, LimitOrder):
return style.get_limit_price()
Expand Down Expand Up @@ -116,7 +98,7 @@ def _submit_order(order_book_id: str, amount, side, position_effect, style, curr

if (side == SIDE.BUY and current_quantity != -amount) or (side == SIDE.SELL and current_quantity != abs(amount)):
# 在融券回测中,需要用买单作为平空,对于此种情况下出现的碎股,亦允许一次性申报卖出
amount = _round_order_quantity(ins, amount)
amount = round_order_quantity(ins, amount)

if amount == 0:
if zero_amount_as_exception:
Expand All @@ -137,16 +119,6 @@ def _submit_order(order_book_id: str, amount, side, position_effect, style, curr
def _order_shares(order_book_id: str, amount, style, quantity, auto_switch_order_value, zero_amount_as_exception=True):
side, position_effect = (SIDE.BUY, POSITION_EFFECT.OPEN) if amount > 0 else (SIDE.SELL, POSITION_EFFECT.CLOSE)
return _submit_order(order_book_id, amount, side, position_effect, style, quantity, auto_switch_order_value, zero_amount_as_exception)


def _estimate_transaction_cost(env: Environment, ins: Instrument, delta_quantity: Union[int, float], price: float) -> float:
if delta_quantity > 0:
side, position_effect = SIDE.BUY, POSITION_EFFECT.OPEN
else:
side, position_effect = SIDE.SELL, POSITION_EFFECT.CLOSE
return env.calc_transaction_cost(TransactionCostArgs(
ins, price, abs(delta_quantity), side, position_effect, # type: ignore
)).total


def _order_value(account: Account, position: AbstractPosition, order_book_id: str, cash_amount: float, style: OrderStyle, zero_amount_as_exception=True):
Expand All @@ -171,24 +143,12 @@ def _order_value(account: Account, position: AbstractPosition, order_book_id: st
ins = assure_active_ins_for_order_api(order_book_id)
if ins is None:
return
exchange_rates = env.data_proxy.get_exchange_rate(env.trading_dt.date(), ins.market)
exchange_rate_middle = (exchange_rates.bid_reference + exchange_rates.ask_reference) / 2
amount = int(Decimal(cash_amount) / Decimal(price * exchange_rate_middle))
if cash_amount > 0:
amount = min(amount, int(Decimal(account.cash) / Decimal(price * exchange_rates.ask_reference)))
round_lot = int(ins.round_lot)
if cash_amount > 0:
amount = _round_order_quantity(ins, amount)
while amount > 0:
expected_transaction_cost = _estimate_transaction_cost(env, ins, amount, price)
if amount * price * exchange_rates.ask_reference + expected_transaction_cost <= cash_amount:
break
amount -= round_lot
else:
if zero_amount_as_exception:
reason = _(u"Order Creation Failed: 0 order quantity, order_book_id={order_book_id}").format(order_book_id=ins.order_book_id)
env.order_creation_failed(order_book_id=order_book_id, reason=reason)
return

amount = get_amount_from_value(cash_amount, ins, price, env, account.cash)
if amount == 0 and zero_amount_as_exception:
reason = _(u"Order Creation Failed: 0 order quantity, order_book_id={order_book_id}").format(order_book_id=ins.order_book_id)
env.order_creation_failed(order_book_id=order_book_id, reason=reason)
return

if amount < 0:
amount = max(amount, -position.closable)
Expand Down Expand Up @@ -428,7 +388,7 @@ def order_target_portfolio(
for order_book_id, (target_percent, open_style, close_style, last_price, ins) in target.items():
current_value = current_quantities.get(order_book_id, 0) * last_price
change_value = target_percent * account_value - current_value
estimate_transaction_cost += _estimate_transaction_cost(env, ins, change_value / last_price, last_price)
estimate_transaction_cost += estimate_transaction_cost_calculator(env, ins, change_value / last_price, last_price)
account_value = account_value - estimate_transaction_cost

close_orders, open_orders = [], []
Expand All @@ -443,7 +403,7 @@ def order_target_portfolio(
env.order_creation_failed(order_book_id=order_book_id, reason=reason)
continue
delta_quantity = (account_value * target_percent / close_price) - current_quantities.get(order_book_id, 0)
delta_quantity = _round_order_quantity(env.data_proxy.instrument(order_book_id), delta_quantity, method=round)
delta_quantity = round_order_quantity(env.data_proxy.instrument(order_book_id), delta_quantity, method=round)

# 优先生成卖单,以便计算出剩余现金,进行买单数量的计算
if delta_quantity == 0:
Expand All @@ -460,13 +420,13 @@ def order_target_portfolio(

estimate_cash = account.cash + sum([o.quantity * o.frozen_price - o.estimated_transaction_cost for o in close_orders])
for order_book_id, (delta_quantity, position_effect, open_style, last_price, ins) in waiting_to_buy.items():
cost = delta_quantity * last_price + _estimate_transaction_cost(env, ins, delta_quantity, last_price)
cost = delta_quantity * last_price + estimate_transaction_cost_calculator(env, ins, delta_quantity, last_price)
if cost > estimate_cash:
delta_quantity = estimate_cash / last_price
delta_quantity = _round_order_quantity(env.data_proxy.instrument(order_book_id), delta_quantity)
delta_quantity = round_order_quantity(env.data_proxy.instrument(order_book_id), delta_quantity)
if delta_quantity == 0:
continue
cost = delta_quantity * last_price + _estimate_transaction_cost(env, ins, delta_quantity, last_price)
cost = delta_quantity * last_price + estimate_transaction_cost_calculator(env, ins, delta_quantity, last_price)
order = Order.__from_create__(order_book_id, delta_quantity, SIDE.BUY, open_style, position_effect)
if isinstance(open_style, MarketOrder):
order.set_frozen_price(last_price)
Expand Down
18 changes: 11 additions & 7 deletions rqalpha/mod/rqalpha_mod_sys_accounts/position_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@
from rqalpha.data.data_proxy import DataProxy
from rqalpha.utils import INST_TYPE_IN_STOCK_ACCOUNT, is_valid_price
from rqalpha.utils.datetime_func import convert_date_to_date_int
from rqalpha.utils.logger import user_system_log
from rqalpha.utils.logger import user_system_log, system_log
from rqalpha.utils.class_helper import deprecated_property
from rqalpha.utils.i18n import gettext as _
from rqalpha.core.events import EVENT, Event
from rqalpha.utils.class_helper import cached_property
from .trade_utils import get_amount_from_value


def _int_to_date(d):
Expand Down Expand Up @@ -251,16 +252,15 @@ def _handle_dividend_payable(self, trading_date: date) -> float:
payable_value += dividend_value
if payable_value and self.dividend_reinvestment:
last_price = self.last_price
amount = int(Decimal(payable_value) / Decimal(last_price))
round_lot = self._instrument.round_lot
amount = int(Decimal(amount) / Decimal(round_lot)) * round_lot
account = self._env.get_account(self._order_book_id)
amount = get_amount_from_value(payable_value, self._instrument, last_price, self._env, account.cash)
if amount > 0:
account = self._env.get_account(self._order_book_id)
trade = Trade.__from_create__(
None, last_price, amount, SIDE.BUY, POSITION_EFFECT.OPEN, self._order_book_id,
)
self._env.event_bus.publish_event(Event(EVENT.TRADE, account=account, trade=trade, order=None))
return payable_value - amount * last_price
return payable_value - amount * last_price - trade.transaction_cost
return payable_value
else:
return payable_value

Expand Down Expand Up @@ -390,7 +390,11 @@ def settlement(self, trading_date):
next_date = data_proxy.get_next_trading_date(trading_date)
if self._env.config.mod.sys_accounts.futures_settlement_price_type == "settlement":
# 逐日盯市按照结算价结算
self._last_price = self._env.data_proxy.get_settle_price(self._order_book_id, self._env.trading_dt)
settle_price = self._env.data_proxy.get_settle_price(self._order_book_id, self._env.trading_dt)
if not is_valid_price(settle_price):
system_log.warning(f"{self._order_book_id} is missing settlement data for {self._env.trading_dt}, close price will be used.")
else:
self._last_price = settle_price
delta_cash += self.equity
self._avg_price = self.last_price
if self._instrument.de_listed_at(next_date):
Expand Down
52 changes: 52 additions & 0 deletions rqalpha/mod/rqalpha_mod_sys_accounts/trade_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Union, Callable
from decimal import Decimal

from rqalpha.environment import Environment
from rqalpha.model.instrument import Instrument
from rqalpha.const import SIDE, POSITION_EFFECT
from rqalpha.interface import TransactionCostArgs


KSH_MIN_AMOUNT = 200
BJSE_MIN_AMOUNT = 100


def estimate_transaction_cost_calculator(env: Environment, ins: Instrument, delta_quantity: Union[int, float], price: float) -> float:
if delta_quantity > 0:
side, position_effect = SIDE.BUY, POSITION_EFFECT.OPEN
else:
side, position_effect = SIDE.SELL, POSITION_EFFECT.CLOSE
return env.calc_transaction_cost(TransactionCostArgs(
ins, price, abs(delta_quantity), side, position_effect, # type: ignore
)).total


def round_order_quantity(ins, quantity, method: Callable = int) -> int:
if ins.type == "CS" and ins.board_type == "KSH":
# KSH can buy(sell) 201, 202 shares
return 0 if abs(quantity) < KSH_MIN_AMOUNT else int(quantity)
elif ins.type == "CS" and ins.board_type == "BJS":
# BJSE can buy(sell) 101, 202 shares
return 0 if abs(quantity) < BJSE_MIN_AMOUNT else int(quantity)
else:
round_lot = ins.round_lot
try:
return method(Decimal(quantity) / Decimal(round_lot)) * round_lot
except ValueError:
raise


def get_amount_from_value(value: float, ins: Instrument, price: float, env: Environment, account_cash: float) -> int:
exchange_rates = env.data_proxy.get_exchange_rate(env.trading_dt.date(), ins.market)
exchange_rate_middle = (exchange_rates.bid_reference + exchange_rates.ask_reference) / 2
amount = int(Decimal(value) / Decimal(price * exchange_rate_middle))
if value > 0:
amount = min(amount, int(Decimal(account_cash) / Decimal(price * exchange_rates.ask_reference)))
amount = round_order_quantity(ins, amount)
while amount > 0:
estimate_transaction_cost = estimate_transaction_cost_calculator(env, ins, amount, price)
if amount * price + estimate_transaction_cost > value:
amount = round_order_quantity(ins, amount - ins.order_step_size)
else:
return amount
return amount
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,46 @@ def handle_bar(context, bar_dict):
# 2. 每 10 股分现金 6 元,总共分 6000 元,再投资买入 400 股
assert get_position(context.s1).quantity == 22400

run_func(config=config, init=init, handle_bar=handle_bar)


def test_dividend_reinvestment_with_transaction():
"""
测试分红再投资时考虑手续费
"""
config = _config({
"base": {
"start_date": "2013-06-10",
"end_date": "2013-06-21",
"accounts": {
"stock": 10000,
},
"init_positions": "000001.XSHE:15000",
},
"extra": {
"log_level": "error",
},
"mod": {
"sys_accounts": {
"dividend_reinvestment": True
}
}
})

def init(context):
context.s1 = "000001.XSHE"
context.fired = False

def handle_bar(context, bar_dict):
if not context.fired:
assert get_position(context.s1).quantity == 15000
context.fired = True
if context.now.date() == date(2013, 6, 20):
# 1. 每 1 股拆为 1.6 股
# 2. 每 10 股分现金 1.7 元,总共分 2550 元,再投资买入 200 股
assert get_position(context.s1).quantity == 24200
# 分红再投资的剩余现金为 2550 - (200 * 11.9187 + 5)
assert context.stock_account.cash - 10000 == 161.25


run_func(config=config, init=init, handle_bar=handle_bar)
Loading