diff --git a/rqalpha/data/data_proxy.py b/rqalpha/data/data_proxy.py index 218464e80..bdaa6cda0 100644 --- a/rqalpha/data/data_proxy.py +++ b/rqalpha/data/data_proxy.py @@ -16,7 +16,7 @@ # 详细的授权流程,请联系 public@ricequant.com 获取。 from datetime import datetime, date -from typing import Union, List, Sequence, Optional, Tuple, Iterable, Dict, Callable +from typing import Union, List, Sequence, Optional, Tuple import numpy as np import pandas as pd @@ -36,9 +36,11 @@ from rqalpha.core.execution_context import ExecutionContext from rqalpha.utils.typing import DateLike from rqalpha.utils.exception import InstrumentNotFound +from rqalpha.data.base_data_source.storages import FuturesTradingParameters from .instruments_mixin import InstrumentsMixin + class DataProxy(TradingDatesMixin, InstrumentsMixin): def __init__(self, data_source: AbstractDataSource, price_board: AbstractPriceBoard): self._data_source = data_source @@ -152,18 +154,18 @@ def _get_settlement(self, instrument, dt): def get_prev_settlement(self, order_book_id, dt): instrument = self.instruments(order_book_id) - if instrument.type not in (INSTRUMENT_TYPE.FUTURE, INSTRUMENT_TYPE.OPTION): + if instrument.type not in (INSTRUMENT_TYPE.FUTURE, INSTRUMENT_TYPE.OPTION, INSTRUMENT_TYPE.SPOT): return np.nan return self._get_prev_settlement(instrument, dt) def get_settlement(self, instrument: Instrument, dt: datetime) -> float: - if instrument.type != INSTRUMENT_TYPE.FUTURE: + if instrument.type not in (INSTRUMENT_TYPE.FUTURE, INSTRUMENT_TYPE.OPTION, INSTRUMENT_TYPE.SPOT): raise LookupError("'{}', instrument_type={}".format(instrument.order_book_id, instrument.type)) return self._get_settlement(instrument, dt) def get_settle_price(self, order_book_id, trading_dt: datetime): instrument = self.get_active_instrument(order_book_id, trading_dt) - if instrument.type != 'Future': + if instrument.type not in (INSTRUMENT_TYPE.FUTURE, INSTRUMENT_TYPE.OPTION, INSTRUMENT_TYPE.SPOT): return np.nan return self._data_source.get_settle_price(instrument, trading_dt) @@ -260,8 +262,7 @@ def tick_fields_for(ins): def available_data_range(self, frequency): return self._data_source.available_data_range(frequency) - def get_futures_trading_parameters(self, order_book_id, dt): - # type: (str, datetime.date) -> FuturesTradingParameters + def get_futures_trading_parameters(self, order_book_id: str, dt: datetime.date) -> FuturesTradingParameters: instrument = self.instruments(order_book_id) return self._data_source.get_futures_trading_parameters(instrument, dt) @@ -288,8 +289,7 @@ def get_tick_size(self, order_book_id): def get_last_price(self, order_book_id: str) -> float: return float(self._price_board.get_last_price(order_book_id)) - def get_future_contracts(self, underlying, date): - # type: (str, DateLike) -> List[str] + def get_future_contracts(self, underlying: str, date: DateLike) -> List[str]: return sorted(i.order_book_id for i in self.all_instruments( [INSTRUMENT_TYPE.FUTURE], date ) if i.underlying_symbol == underlying and not Instrument.is_future_continuous_contract(i.order_book_id)) diff --git a/rqalpha/mod/rqalpha_mod_sys_accounts/api/api_stock.py b/rqalpha/mod/rqalpha_mod_sys_accounts/api/api_stock.py index 2446fb311..305783faf 100644 --- a/rqalpha/mod/rqalpha_mod_sys_accounts/api/api_stock.py +++ b/rqalpha/mod/rqalpha_mod_sys_accounts/api/api_stock.py @@ -50,6 +50,7 @@ from rqalpha.utils.exception import RQInvalidArgument from rqalpha.utils.i18n import gettext as _ from rqalpha.utils.logger import user_system_log +from rqalpha.mod.rqalpha_mod_sys_accounts.trade_utils import estimate_transaction_cost_calculator, round_order_quantity, get_amount_from_value from .order_target_portfolio import order_target_portfolio_smart # 使用Decimal 解决浮点数运算精度问题 @@ -59,10 +60,6 @@ export_as_api(sector_code, name='sector_code') -KSH_MIN_AMOUNT = 200 -BJSE_MIN_AMOUNT = 100 - - def _get_account_position(order_book_id: str): try: account = Environment.get_instance().portfolio.accounts[DEFAULT_ACCOUNT_TYPE.STOCK] @@ -74,21 +71,6 @@ def _get_account_position(order_book_id: str): return account, position -def _round_order_quantity(ins, quantity, method: Callable = int) -> int: - if ins.type == "CS" and ins.board_type == "KSH": - # KSH can buy(sell) 201, 202 shares - return 0 if abs(quantity) < KSH_MIN_AMOUNT else int(quantity) - elif ins.type == "CS" and ins.board_type == "BJS": - # BJSE can buy(sell) 101, 202 shares - return 0 if abs(quantity) < BJSE_MIN_AMOUNT else int(quantity) - else: - round_lot = ins.round_lot - try: - return method(Decimal(quantity) / Decimal(round_lot)) * round_lot - except ValueError: - raise - - def _get_order_style_price(order_book_id, style): if isinstance(style, LimitOrder): return style.get_limit_price() @@ -116,7 +98,7 @@ def _submit_order(order_book_id: str, amount, side, position_effect, style, curr if (side == SIDE.BUY and current_quantity != -amount) or (side == SIDE.SELL and current_quantity != abs(amount)): # 在融券回测中,需要用买单作为平空,对于此种情况下出现的碎股,亦允许一次性申报卖出 - amount = _round_order_quantity(ins, amount) + amount = round_order_quantity(ins, amount) if amount == 0: if zero_amount_as_exception: @@ -137,16 +119,6 @@ def _submit_order(order_book_id: str, amount, side, position_effect, style, curr def _order_shares(order_book_id: str, amount, style, quantity, auto_switch_order_value, zero_amount_as_exception=True): side, position_effect = (SIDE.BUY, POSITION_EFFECT.OPEN) if amount > 0 else (SIDE.SELL, POSITION_EFFECT.CLOSE) return _submit_order(order_book_id, amount, side, position_effect, style, quantity, auto_switch_order_value, zero_amount_as_exception) - - -def _estimate_transaction_cost(env: Environment, ins: Instrument, delta_quantity: Union[int, float], price: float) -> float: - if delta_quantity > 0: - side, position_effect = SIDE.BUY, POSITION_EFFECT.OPEN - else: - side, position_effect = SIDE.SELL, POSITION_EFFECT.CLOSE - return env.calc_transaction_cost(TransactionCostArgs( - ins, price, abs(delta_quantity), side, position_effect, # type: ignore - )).total def _order_value(account: Account, position: AbstractPosition, order_book_id: str, cash_amount: float, style: OrderStyle, zero_amount_as_exception=True): @@ -171,24 +143,12 @@ def _order_value(account: Account, position: AbstractPosition, order_book_id: st ins = assure_active_ins_for_order_api(order_book_id) if ins is None: return - exchange_rates = env.data_proxy.get_exchange_rate(env.trading_dt.date(), ins.market) - exchange_rate_middle = (exchange_rates.bid_reference + exchange_rates.ask_reference) / 2 - amount = int(Decimal(cash_amount) / Decimal(price * exchange_rate_middle)) - if cash_amount > 0: - amount = min(amount, int(Decimal(account.cash) / Decimal(price * exchange_rates.ask_reference))) - round_lot = int(ins.round_lot) - if cash_amount > 0: - amount = _round_order_quantity(ins, amount) - while amount > 0: - expected_transaction_cost = _estimate_transaction_cost(env, ins, amount, price) - if amount * price * exchange_rates.ask_reference + expected_transaction_cost <= cash_amount: - break - amount -= round_lot - else: - if zero_amount_as_exception: - reason = _(u"Order Creation Failed: 0 order quantity, order_book_id={order_book_id}").format(order_book_id=ins.order_book_id) - env.order_creation_failed(order_book_id=order_book_id, reason=reason) - return + + amount = get_amount_from_value(cash_amount, ins, price, env, account.cash) + if amount == 0 and zero_amount_as_exception: + reason = _(u"Order Creation Failed: 0 order quantity, order_book_id={order_book_id}").format(order_book_id=ins.order_book_id) + env.order_creation_failed(order_book_id=order_book_id, reason=reason) + return if amount < 0: amount = max(amount, -position.closable) @@ -428,7 +388,7 @@ def order_target_portfolio( for order_book_id, (target_percent, open_style, close_style, last_price, ins) in target.items(): current_value = current_quantities.get(order_book_id, 0) * last_price change_value = target_percent * account_value - current_value - estimate_transaction_cost += _estimate_transaction_cost(env, ins, change_value / last_price, last_price) + estimate_transaction_cost += estimate_transaction_cost_calculator(env, ins, change_value / last_price, last_price) account_value = account_value - estimate_transaction_cost close_orders, open_orders = [], [] @@ -443,7 +403,7 @@ def order_target_portfolio( env.order_creation_failed(order_book_id=order_book_id, reason=reason) continue delta_quantity = (account_value * target_percent / close_price) - current_quantities.get(order_book_id, 0) - delta_quantity = _round_order_quantity(env.data_proxy.instrument(order_book_id), delta_quantity, method=round) + delta_quantity = round_order_quantity(env.data_proxy.instrument(order_book_id), delta_quantity, method=round) # 优先生成卖单,以便计算出剩余现金,进行买单数量的计算 if delta_quantity == 0: @@ -460,13 +420,13 @@ def order_target_portfolio( estimate_cash = account.cash + sum([o.quantity * o.frozen_price - o.estimated_transaction_cost for o in close_orders]) for order_book_id, (delta_quantity, position_effect, open_style, last_price, ins) in waiting_to_buy.items(): - cost = delta_quantity * last_price + _estimate_transaction_cost(env, ins, delta_quantity, last_price) + cost = delta_quantity * last_price + estimate_transaction_cost_calculator(env, ins, delta_quantity, last_price) if cost > estimate_cash: delta_quantity = estimate_cash / last_price - delta_quantity = _round_order_quantity(env.data_proxy.instrument(order_book_id), delta_quantity) + delta_quantity = round_order_quantity(env.data_proxy.instrument(order_book_id), delta_quantity) if delta_quantity == 0: continue - cost = delta_quantity * last_price + _estimate_transaction_cost(env, ins, delta_quantity, last_price) + cost = delta_quantity * last_price + estimate_transaction_cost_calculator(env, ins, delta_quantity, last_price) order = Order.__from_create__(order_book_id, delta_quantity, SIDE.BUY, open_style, position_effect) if isinstance(open_style, MarketOrder): order.set_frozen_price(last_price) diff --git a/rqalpha/mod/rqalpha_mod_sys_accounts/position_model.py b/rqalpha/mod/rqalpha_mod_sys_accounts/position_model.py index 79b7b7f1a..cedf316de 100644 --- a/rqalpha/mod/rqalpha_mod_sys_accounts/position_model.py +++ b/rqalpha/mod/rqalpha_mod_sys_accounts/position_model.py @@ -29,11 +29,12 @@ from rqalpha.data.data_proxy import DataProxy from rqalpha.utils import INST_TYPE_IN_STOCK_ACCOUNT, is_valid_price from rqalpha.utils.datetime_func import convert_date_to_date_int -from rqalpha.utils.logger import user_system_log +from rqalpha.utils.logger import user_system_log, system_log from rqalpha.utils.class_helper import deprecated_property from rqalpha.utils.i18n import gettext as _ from rqalpha.core.events import EVENT, Event from rqalpha.utils.class_helper import cached_property +from .trade_utils import get_amount_from_value def _int_to_date(d): @@ -251,16 +252,15 @@ def _handle_dividend_payable(self, trading_date: date) -> float: payable_value += dividend_value if payable_value and self.dividend_reinvestment: last_price = self.last_price - amount = int(Decimal(payable_value) / Decimal(last_price)) - round_lot = self._instrument.round_lot - amount = int(Decimal(amount) / Decimal(round_lot)) * round_lot + account = self._env.get_account(self._order_book_id) + amount = get_amount_from_value(payable_value, self._instrument, last_price, self._env, account.cash) if amount > 0: - account = self._env.get_account(self._order_book_id) trade = Trade.__from_create__( None, last_price, amount, SIDE.BUY, POSITION_EFFECT.OPEN, self._order_book_id, ) self._env.event_bus.publish_event(Event(EVENT.TRADE, account=account, trade=trade, order=None)) - return payable_value - amount * last_price + return payable_value - amount * last_price - trade.transaction_cost + return payable_value else: return payable_value @@ -390,7 +390,11 @@ def settlement(self, trading_date): next_date = data_proxy.get_next_trading_date(trading_date) if self._env.config.mod.sys_accounts.futures_settlement_price_type == "settlement": # 逐日盯市按照结算价结算 - self._last_price = self._env.data_proxy.get_settle_price(self._order_book_id, self._env.trading_dt) + settle_price = self._env.data_proxy.get_settle_price(self._order_book_id, self._env.trading_dt) + if not is_valid_price(settle_price): + system_log.warning(f"{self._order_book_id} is missing settlement data for {self._env.trading_dt}, close price will be used.") + else: + self._last_price = settle_price delta_cash += self.equity self._avg_price = self.last_price if self._instrument.de_listed_at(next_date): diff --git a/rqalpha/mod/rqalpha_mod_sys_accounts/trade_utils.py b/rqalpha/mod/rqalpha_mod_sys_accounts/trade_utils.py new file mode 100644 index 000000000..21a9fc8c9 --- /dev/null +++ b/rqalpha/mod/rqalpha_mod_sys_accounts/trade_utils.py @@ -0,0 +1,52 @@ +from typing import Union, Callable +from decimal import Decimal + +from rqalpha.environment import Environment +from rqalpha.model.instrument import Instrument +from rqalpha.const import SIDE, POSITION_EFFECT +from rqalpha.interface import TransactionCostArgs + + +KSH_MIN_AMOUNT = 200 +BJSE_MIN_AMOUNT = 100 + + +def estimate_transaction_cost_calculator(env: Environment, ins: Instrument, delta_quantity: Union[int, float], price: float) -> float: + if delta_quantity > 0: + side, position_effect = SIDE.BUY, POSITION_EFFECT.OPEN + else: + side, position_effect = SIDE.SELL, POSITION_EFFECT.CLOSE + return env.calc_transaction_cost(TransactionCostArgs( + ins, price, abs(delta_quantity), side, position_effect, # type: ignore + )).total + + +def round_order_quantity(ins, quantity, method: Callable = int) -> int: + if ins.type == "CS" and ins.board_type == "KSH": + # KSH can buy(sell) 201, 202 shares + return 0 if abs(quantity) < KSH_MIN_AMOUNT else int(quantity) + elif ins.type == "CS" and ins.board_type == "BJS": + # BJSE can buy(sell) 101, 202 shares + return 0 if abs(quantity) < BJSE_MIN_AMOUNT else int(quantity) + else: + round_lot = ins.round_lot + try: + return method(Decimal(quantity) / Decimal(round_lot)) * round_lot + except ValueError: + raise + + +def get_amount_from_value(value: float, ins: Instrument, price: float, env: Environment, account_cash: float) -> int: + exchange_rates = env.data_proxy.get_exchange_rate(env.trading_dt.date(), ins.market) + exchange_rate_middle = (exchange_rates.bid_reference + exchange_rates.ask_reference) / 2 + amount = int(Decimal(value) / Decimal(price * exchange_rate_middle)) + if value > 0: + amount = min(amount, int(Decimal(account_cash) / Decimal(price * exchange_rates.ask_reference))) + amount = round_order_quantity(ins, amount) + while amount > 0: + estimate_transaction_cost = estimate_transaction_cost_calculator(env, ins, amount, price) + if amount * price + estimate_transaction_cost > value: + amount = round_order_quantity(ins, amount - ins.order_step_size) + else: + return amount + return amount \ No newline at end of file diff --git a/tests/integration_tests/test_api/mod/sys_accounts/test_position_models.py b/tests/integration_tests/test_api/mod/sys_accounts/test_position_models.py index 43da0d065..db9bf0ff2 100644 --- a/tests/integration_tests/test_api/mod/sys_accounts/test_position_models.py +++ b/tests/integration_tests/test_api/mod/sys_accounts/test_position_models.py @@ -139,4 +139,46 @@ def handle_bar(context, bar_dict): # 2. 每 10 股分现金 6 元,总共分 6000 元,再投资买入 400 股 assert get_position(context.s1).quantity == 22400 + run_func(config=config, init=init, handle_bar=handle_bar) + + +def test_dividend_reinvestment_with_transaction(): + """ + 测试分红再投资时考虑手续费 + """ + config = _config({ + "base": { + "start_date": "2013-06-10", + "end_date": "2013-06-21", + "accounts": { + "stock": 10000, + }, + "init_positions": "000001.XSHE:15000", + }, + "extra": { + "log_level": "error", + }, + "mod": { + "sys_accounts": { + "dividend_reinvestment": True + } + } + }) + + def init(context): + context.s1 = "000001.XSHE" + context.fired = False + + def handle_bar(context, bar_dict): + if not context.fired: + assert get_position(context.s1).quantity == 15000 + context.fired = True + if context.now.date() == date(2013, 6, 20): + # 1. 每 1 股拆为 1.6 股 + # 2. 每 10 股分现金 1.7 元,总共分 2550 元,再投资买入 200 股 + assert get_position(context.s1).quantity == 24200 + # 分红再投资的剩余现金为 2550 - (200 * 11.9187 + 5) + assert context.stock_account.cash - 10000 == 161.25 + + run_func(config=config, init=init, handle_bar=handle_bar) \ No newline at end of file