Skip to content

Commit 0cc4e21

Browse files
authored
ZJIT: Get type information from branchif, branchunless, branchnil instructions (ruby#15915)
Do a sort of "partial static single information (SSI)" form that learns types of operands from branch instructions. A branchif, for example, tells us that in the truthy path, we know the operand is not nil, and not false. Similarly, in the falsy path, we know the operand is either nil or false. Add a RefineType instruction to attach this information. This PR does this in SSA construction because it's pretty straightforward, but we can also do a more aggressive version of this that can learn information about e.g. int ranges from other checks later in the optimization pipeline.
1 parent cfa97af commit 0cc4e21

7 files changed

Lines changed: 324 additions & 133 deletions

File tree

zjit/src/codegen.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
524524
&Insn::BoxFixnum { val, state } => gen_box_fixnum(jit, asm, opnd!(val), &function.frame_state(state)),
525525
&Insn::UnboxFixnum { val } => gen_unbox_fixnum(asm, opnd!(val)),
526526
Insn::Test { val } => gen_test(asm, opnd!(val)),
527+
Insn::RefineType { val, .. } => opnd!(val),
527528
Insn::GuardType { val, guard_type, state } => gen_guard_type(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)),
528529
Insn::GuardTypeNot { val, guard_type, state } => gen_guard_type_not(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)),
529530
&Insn::GuardBitEquals { val, expected, reason, state } => gen_guard_bit_equals(jit, asm, opnd!(val), expected, reason, &function.frame_state(state)),

zjit/src/hir.rs

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,10 @@ pub enum Insn {
994994
ObjToString { val: InsnId, cd: *const rb_call_data, state: InsnId },
995995
AnyToString { val: InsnId, str: InsnId, state: InsnId },
996996

997+
/// Refine the known type information of with additional type information.
998+
/// Computes the intersection of the existing type and the new type.
999+
RefineType { val: InsnId, new_type: Type },
1000+
9971001
/// Side-exit if val doesn't have the expected type.
9981002
GuardType { val: InsnId, guard_type: Type, state: InsnId },
9991003
GuardTypeNot { val: InsnId, guard_type: Type, state: InsnId },
@@ -1212,6 +1216,7 @@ impl Insn {
12121216
Insn::IncrCounterPtr { .. } => effects::Any,
12131217
Insn::CheckInterrupts { .. } => effects::Any,
12141218
Insn::InvokeProc { .. } => effects::Any,
1219+
Insn::RefineType { .. } => effects::Empty,
12151220
}
12161221
}
12171222

@@ -1507,6 +1512,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
15071512
Insn::FixnumLShift { left, right, .. } => { write!(f, "FixnumLShift {left}, {right}") },
15081513
Insn::FixnumRShift { left, right, .. } => { write!(f, "FixnumRShift {left}, {right}") },
15091514
Insn::GuardType { val, guard_type, .. } => { write!(f, "GuardType {val}, {}", guard_type.print(self.ptr_map)) },
1515+
Insn::RefineType { val, new_type, .. } => { write!(f, "RefineType {val}, {}", new_type.print(self.ptr_map)) },
15101516
Insn::GuardTypeNot { val, guard_type, .. } => { write!(f, "GuardTypeNot {val}, {}", guard_type.print(self.ptr_map)) },
15111517
Insn::GuardBitEquals { val, expected, .. } => { write!(f, "GuardBitEquals {val}, {}", expected.print(self.ptr_map)) },
15121518
&Insn::GuardShape { val, shape, .. } => { write!(f, "GuardShape {val}, {:p}", self.ptr_map.map_shape(shape)) },
@@ -2174,6 +2180,7 @@ impl Function {
21742180
Jump(target) => Jump(find_branch_edge!(target)),
21752181
&IfTrue { val, ref target } => IfTrue { val: find!(val), target: find_branch_edge!(target) },
21762182
&IfFalse { val, ref target } => IfFalse { val: find!(val), target: find_branch_edge!(target) },
2183+
&RefineType { val, new_type } => RefineType { val: find!(val), new_type },
21772184
&GuardType { val, guard_type, state } => GuardType { val: find!(val), guard_type, state },
21782185
&GuardTypeNot { val, guard_type, state } => GuardTypeNot { val: find!(val), guard_type, state },
21792186
&GuardBitEquals { val, expected, reason, state } => GuardBitEquals { val: find!(val), expected, reason, state },
@@ -2423,6 +2430,7 @@ impl Function {
24232430
Insn::CCall { return_type, .. } => *return_type,
24242431
&Insn::CCallVariadic { return_type, .. } => return_type,
24252432
Insn::GuardType { val, guard_type, .. } => self.type_of(*val).intersection(*guard_type),
2433+
Insn::RefineType { val, new_type, .. } => self.type_of(*val).intersection(*new_type),
24262434
Insn::GuardTypeNot { .. } => types::BasicObject,
24272435
Insn::GuardBitEquals { val, expected, .. } => self.type_of(*val).intersection(Type::from_const(*expected)),
24282436
Insn::GuardShape { val, .. } => self.type_of(*val),
@@ -2594,6 +2602,7 @@ impl Function {
25942602
| Insn::GuardTypeNot { val, .. }
25952603
| Insn::GuardShape { val, .. }
25962604
| Insn::GuardBitEquals { val, .. } => self.chase_insn(val),
2605+
| Insn::RefineType { val, .. } => self.chase_insn(val),
25972606
_ => id,
25982607
}
25992608
}
@@ -4445,6 +4454,7 @@ impl Function {
44454454
worklist.extend(values);
44464455
worklist.push_back(state);
44474456
}
4457+
| &Insn::RefineType { val, .. }
44484458
| &Insn::Return { val }
44494459
| &Insn::Test { val }
44504460
| &Insn::SetLocal { val, .. }
@@ -5370,6 +5380,7 @@ impl Function {
53705380
self.assert_subtype(insn_id, val, types::BasicObject)?;
53715381
self.assert_subtype(insn_id, class, types::Class)
53725382
}
5383+
Insn::RefineType { .. } => Ok(()),
53735384
}
53745385
}
53755386

@@ -5562,6 +5573,19 @@ impl FrameState {
55625573
state.stack.extend_from_slice(new_args);
55635574
state
55645575
}
5576+
5577+
fn replace(&mut self, old: InsnId, new: InsnId) {
5578+
for slot in &mut self.stack {
5579+
if *slot == old {
5580+
*slot = new;
5581+
}
5582+
}
5583+
for slot in &mut self.locals {
5584+
if *slot == old {
5585+
*slot = new;
5586+
}
5587+
}
5588+
}
55655589
}
55665590

55675591
/// Print adaptor for [`FrameState`]. See [`PtrPrintMap`].
@@ -6245,10 +6269,17 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
62456269
let test_id = fun.push_insn(block, Insn::Test { val });
62466270
let target_idx = insn_idx_at_offset(insn_idx, offset);
62476271
let target = insn_idx_to_block[&target_idx];
6272+
let nil_false_type = types::NilClass.union(types::FalseClass);
6273+
let nil_false = fun.push_insn(block, Insn::RefineType { val, new_type: nil_false_type });
6274+
let mut iffalse_state = state.clone();
6275+
iffalse_state.replace(val, nil_false);
62486276
let _branch_id = fun.push_insn(block, Insn::IfFalse {
62496277
val: test_id,
6250-
target: BranchEdge { target, args: state.as_args(self_param) }
6278+
target: BranchEdge { target, args: iffalse_state.as_args(self_param) }
62516279
});
6280+
let not_nil_false_type = types::BasicObject.subtract(types::NilClass).subtract(types::FalseClass);
6281+
let not_nil_false = fun.push_insn(block, Insn::RefineType { val, new_type: not_nil_false_type });
6282+
state.replace(val, not_nil_false);
62526283
queue.push_back((state.clone(), target, target_idx, local_inval));
62536284
}
62546285
YARVINSN_branchif => {
@@ -6258,10 +6289,17 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
62586289
let test_id = fun.push_insn(block, Insn::Test { val });
62596290
let target_idx = insn_idx_at_offset(insn_idx, offset);
62606291
let target = insn_idx_to_block[&target_idx];
6292+
let not_nil_false_type = types::BasicObject.subtract(types::NilClass).subtract(types::FalseClass);
6293+
let not_nil_false = fun.push_insn(block, Insn::RefineType { val, new_type: not_nil_false_type });
6294+
let mut iftrue_state = state.clone();
6295+
iftrue_state.replace(val, not_nil_false);
62616296
let _branch_id = fun.push_insn(block, Insn::IfTrue {
62626297
val: test_id,
6263-
target: BranchEdge { target, args: state.as_args(self_param) }
6298+
target: BranchEdge { target, args: iftrue_state.as_args(self_param) }
62646299
});
6300+
let nil_false_type = types::NilClass.union(types::FalseClass);
6301+
let nil_false = fun.push_insn(block, Insn::RefineType { val, new_type: nil_false_type });
6302+
state.replace(val, nil_false);
62656303
queue.push_back((state.clone(), target, target_idx, local_inval));
62666304
}
62676305
YARVINSN_branchnil => {
@@ -6271,10 +6309,16 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
62716309
let test_id = fun.push_insn(block, Insn::IsNil { val });
62726310
let target_idx = insn_idx_at_offset(insn_idx, offset);
62736311
let target = insn_idx_to_block[&target_idx];
6312+
let nil = fun.push_insn(block, Insn::Const { val: Const::Value(Qnil) });
6313+
let mut iftrue_state = state.clone();
6314+
iftrue_state.replace(val, nil);
62746315
let _branch_id = fun.push_insn(block, Insn::IfTrue {
62756316
val: test_id,
6276-
target: BranchEdge { target, args: state.as_args(self_param) }
6317+
target: BranchEdge { target, args: iftrue_state.as_args(self_param) }
62776318
});
6319+
let new_type = types::BasicObject.subtract(types::NilClass);
6320+
let not_nil = fun.push_insn(block, Insn::RefineType { val, new_type });
6321+
state.replace(val, not_nil);
62786322
queue.push_back((state.clone(), target, target_idx, local_inval));
62796323
}
62806324
YARVINSN_opt_case_dispatch => {
@@ -7693,21 +7737,23 @@ mod graphviz_tests {
76937737
<TR><TD ALIGN="left" PORT="v12">PatchPoint NoTracePoint&nbsp;</TD></TR>
76947738
<TR><TD ALIGN="left" PORT="v14">CheckInterrupts&nbsp;</TD></TR>
76957739
<TR><TD ALIGN="left" PORT="v15">v15:CBool = Test v9&nbsp;</TD></TR>
7696-
<TR><TD ALIGN="left" PORT="v16">IfFalse v15, bb3(v8, v9)&nbsp;</TD></TR>
7697-
<TR><TD ALIGN="left" PORT="v18">PatchPoint NoTracePoint&nbsp;</TD></TR>
7698-
<TR><TD ALIGN="left" PORT="v19">v19:Fixnum[3] = Const Value(3)&nbsp;</TD></TR>
7699-
<TR><TD ALIGN="left" PORT="v21">PatchPoint NoTracePoint&nbsp;</TD></TR>
7700-
<TR><TD ALIGN="left" PORT="v22">CheckInterrupts&nbsp;</TD></TR>
7701-
<TR><TD ALIGN="left" PORT="v23">Return v19&nbsp;</TD></TR>
7740+
<TR><TD ALIGN="left" PORT="v16">v16:Falsy = RefineType v9, Falsy&nbsp;</TD></TR>
7741+
<TR><TD ALIGN="left" PORT="v17">IfFalse v15, bb3(v8, v16)&nbsp;</TD></TR>
7742+
<TR><TD ALIGN="left" PORT="v18">v18:Truthy = RefineType v9, Truthy&nbsp;</TD></TR>
7743+
<TR><TD ALIGN="left" PORT="v20">PatchPoint NoTracePoint&nbsp;</TD></TR>
7744+
<TR><TD ALIGN="left" PORT="v21">v21:Fixnum[3] = Const Value(3)&nbsp;</TD></TR>
7745+
<TR><TD ALIGN="left" PORT="v23">PatchPoint NoTracePoint&nbsp;</TD></TR>
7746+
<TR><TD ALIGN="left" PORT="v24">CheckInterrupts&nbsp;</TD></TR>
7747+
<TR><TD ALIGN="left" PORT="v25">Return v21&nbsp;</TD></TR>
77027748
</TABLE>>];
7703-
bb2:v16 -> bb3:params:n;
7749+
bb2:v17 -> bb3:params:n;
77047750
bb3 [label=<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">
7705-
<TR><TD ALIGN="LEFT" PORT="params" BGCOLOR="gray">bb3(v24:BasicObject, v25:BasicObject)&nbsp;</TD></TR>
7706-
<TR><TD ALIGN="left" PORT="v28">PatchPoint NoTracePoint&nbsp;</TD></TR>
7707-
<TR><TD ALIGN="left" PORT="v29">v29:Fixnum[4] = Const Value(4)&nbsp;</TD></TR>
7708-
<TR><TD ALIGN="left" PORT="v31">PatchPoint NoTracePoint&nbsp;</TD></TR>
7709-
<TR><TD ALIGN="left" PORT="v32">CheckInterrupts&nbsp;</TD></TR>
7710-
<TR><TD ALIGN="left" PORT="v33">Return v29&nbsp;</TD></TR>
7751+
<TR><TD ALIGN="LEFT" PORT="params" BGCOLOR="gray">bb3(v26:BasicObject, v27:Falsy)&nbsp;</TD></TR>
7752+
<TR><TD ALIGN="left" PORT="v30">PatchPoint NoTracePoint&nbsp;</TD></TR>
7753+
<TR><TD ALIGN="left" PORT="v31">v31:Fixnum[4] = Const Value(4)&nbsp;</TD></TR>
7754+
<TR><TD ALIGN="left" PORT="v33">PatchPoint NoTracePoint&nbsp;</TD></TR>
7755+
<TR><TD ALIGN="left" PORT="v34">CheckInterrupts&nbsp;</TD></TR>
7756+
<TR><TD ALIGN="left" PORT="v35">Return v31&nbsp;</TD></TR>
77117757
</TABLE>>];
77127758
}
77137759
"#);

0 commit comments

Comments
 (0)