Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions NN/API/Models/Gpt2.lean
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,75 @@ def causalTransformerTokenScalarModuleDef (cfg : CausalOneHotConfig)
causalTransformerTokenScalarModuleDefWithMode .train cfg body tokens targets
(reduction := reduction)

abbrev causalTokenLmInputShape (cfg : CausalOneHotConfig) : Shape :=
.dim (cfg.batch * cfg.seqLen) .scalar

/--
Scalar loss for causal language modeling with per-step float-encoded token ids as inputs.

`xTokens` and `yTokens` are flattened `(batch * seqLen)` float vectors holding integer token ids.
This matches PyTorch `nn.Embedding` + `F.cross_entropy` training and avoids rebuilding the module
each step (Adam state stays on one session).
-/
def causalTransformerTokenLmScalarModuleDefWithMode
(mode : _root_.Runtime.Autograd.TorchLean.NN.Mode)
(cfg : CausalOneHotConfig)
(body : nn.Sequential (causalEmbeddingShape cfg) (causalOneHotShape cfg))
(initParams : _root_.Runtime.Autograd.Torch.TList Float
((.dim cfg.vocab (.dim cfg.dModel .scalar)) :: paramShapes body))
(reduction : TorchLean.Loss.Reduction := .mean) :
TorchLean.Module.ScalarModuleDef
((.dim cfg.vocab (.dim cfg.dModel .scalar)) :: paramShapes body)
[causalTokenLmInputShape cfg, causalTokenLmInputShape cfg] :=
{ initParams := initParams
initRequiresGrad :=
List.replicate (((.dim cfg.vocab (.dim cfg.dModel .scalar)) :: paramShapes body).length) true
loss := fun {α} => by
intro _ _
exact fun {m} _ _ =>
_root_.Runtime.Autograd.Torch.CurriedRef.curry
(Ref := _root_.Runtime.Autograd.TorchLean.NN.Seq.RefT (m := m) (α := α))
(ss := ((.dim cfg.vocab (.dim cfg.dModel .scalar)) :: paramShapes body) ++
[causalTokenLmInputShape cfg, causalTokenLmInputShape cfg])
(β := m (_root_.Runtime.Autograd.TorchLean.NN.Seq.RefT (m := m) (α := α)
Spec.Shape.scalar))
(fun args => do
let (ps, ins) :=
_root_.Runtime.Autograd.Torch.RefList.split
(Ref := _root_.Runtime.Autograd.TorchLean.NN.Seq.RefT (m := m) (α := α))
(ss₁ := (.dim cfg.vocab (.dim cfg.dModel .scalar)) :: paramShapes body)
(ss₂ := [causalTokenLmInputShape cfg, causalTokenLmInputShape cfg]) args
let .cons xFloat (.cons yFloat .nil) := ins
let .cons tokenEmbedding bodyParams := ps
let tokens ← _root_.Runtime.Autograd.TorchLean.F.floatVecToNatTensor (m := m) (α := α)
(k := cfg.batch * cfg.seqLen) xFloat
let targets ← _root_.Runtime.Autograd.TorchLean.F.floatVecToNatTensor (m := m) (α := α)
(k := cfg.batch * cfg.seqLen) yFloat
let x ← _root_.Runtime.Autograd.TorchLean.F.embeddingBatchSeqNat (m := m) (α := α)
(vocab := cfg.vocab) (dim := cfg.dModel) (batch := cfg.batch)
(seqLen := cfg.seqLen) tokenEmbedding tokens
let logits ← _root_.Runtime.Autograd.TorchLean.NN.Seq.evalParams
(model := body) (α := α) (m := m) mode bodyParams x
let logitsRows ← _root_.Runtime.Autograd.Torch.reshape (m := m) (α := α)
(s₁ := .dim cfg.batch (.dim cfg.seqLen (.dim cfg.vocab .scalar)))
(s₂ := .dim (cfg.batch * cfg.seqLen) (.dim cfg.vocab .scalar))
logits (by
simp [_root_.Spec.Shape.size, Nat.mul_assoc])
_root_.Runtime.Autograd.TorchLean.Loss.crossEntropyRowsNat (m := m) (α := α)
(rows := cfg.batch * cfg.seqLen) (classes := cfg.vocab)
logitsRows targets (reduction := reduction)) }

/-- Training-mode wrapper for float-encoded token-id causal language modeling. -/
def causalTransformerTokenLmScalarModuleDef (cfg : CausalOneHotConfig)
(body : nn.Sequential (causalEmbeddingShape cfg) (causalOneHotShape cfg))
(initParams : _root_.Runtime.Autograd.Torch.TList Float
((.dim cfg.vocab (.dim cfg.dModel .scalar)) :: paramShapes body))
(reduction : TorchLean.Loss.Reduction := .mean) :
TorchLean.Module.ScalarModuleDef
((.dim cfg.vocab (.dim cfg.dModel .scalar)) :: paramShapes body)
[causalTokenLmInputShape cfg, causalTokenLmInputShape cfg] :=
causalTransformerTokenLmScalarModuleDefWithMode .train cfg body initParams (reduction := reduction)

end models
end nn

Expand Down
41 changes: 41 additions & 0 deletions NN/API/Public/Facade/Data/Text.lean
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,47 @@ def causalLmOneHotSampleRowsFromTokenArray
text.Corpus.randomBatchTokenWindows tokens batch seqLen seed step (padId := padId)
causalLmOneHotSampleRows (α := α) batch seqLen vocab idsAt (padId := padId)

/--
Flatten one `(seqLen + 1)` token window into causal-LM `(x, y)` id lists.
-/
def causalLmTokenRows (seqLen : Nat) (window : List Nat) (padId : Nat := 0) :
List Nat × List Nat :=
let x := (List.range seqLen).map (fun i => window.getD i padId)
let y := (List.range seqLen).map (fun i => window.getD (i + 1) padId)
(x, y)

/--
Build a float tensor of integer token ids from a flat `List Nat` of length `batch * seqLen`.
-/
def causalLmTokenFloatVec {α : Type} [Runtime.SemanticScalar α] [Runtime.Scalar α]
(batch seqLen : Nat) (tokens : List Nat) :
Tensor.T α (.dim (batch * seqLen) .scalar) :=
let xF : _root_.Spec.Tensor Float (.dim (batch * seqLen) .scalar) :=
_root_.Spec.Tensor.dim (fun i : Fin (batch * seqLen) =>
_root_.Spec.Tensor.scalar (Float.ofNat (tokens.getD i.val 0)))
Tensor.castFloat Runtime.ofFloat xF

/--
Build a batched token-id causal-language-model sample from an array-backed corpus.

Token ids are passed as float inputs so the training loop can swap windows each step without
re-instantiating the scalar module.
-/
def causalLmTokenSampleRowsFromTokenArray
{α : Type} [Runtime.SemanticScalar α] [Runtime.Scalar α]
(batch seqLen : Nat) (tokens : Array Nat) (seed step : Nat) (padId : Nat := 0) :
SupervisedSample α (.dim (batch * seqLen) .scalar) (.dim (batch * seqLen) .scalar) :=
let idsAt :=
text.Corpus.randomBatchTokenWindows tokens batch seqLen seed step (padId := padId)
let (xList, yList) :=
(List.finRange batch).foldl (fun (acc : List Nat × List Nat) bi =>
let (xs, ys) := acc
let (xRow, yRow) := causalLmTokenRows seqLen (idsAt bi) (padId := padId)
(xs ++ xRow, ys ++ yRow)) ([], [])
NN.API.Sample.mk
(causalLmTokenFloatVec (α := α) batch seqLen xList)
(causalLmTokenFloatVec (α := α) batch seqLen yList)

/--
Build one unbatched one-hot causal-language-model sample directly from a token list.

Expand Down
4 changes: 3 additions & 1 deletion NN/API/Runtime/Module.lean
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ Default dtype policy:
- If the user does not specify `--dtype` / `--float32-mode` and `--cuda` is present, default to
`dtype=float` (CUDA eager supports `Float` upload/download).
- Otherwise default to `dtype=float32` (executable IEEE-754 float32 semantics).

When `--cuda` is selected, also enable the CUDA fast path (`--fast-kernels`) by default.
-/
def parseAndStripWithDefaultDType (args : List String) (defaultDType : DType) :
Except String (ExecConfig × List String) := do
Expand All @@ -214,7 +216,7 @@ def parseAndStripWithDefaultDType (args : List String) (defaultDType : DType) :
dtype := dtype,
backend := backend,
useGpu := useGpu,
fastKernels := fastKernels,
fastKernels := fastKernels || useGpu,
fastGpuMatmulPrecision := fastGpuMatmulPrecision
}, rest)

Expand Down
13 changes: 13 additions & 0 deletions NN/Runtime/Autograd/Torch/Core/Functional.lean
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,14 @@ class Ops (m : Type → Type) (α : Type) [Context α] [DecidableEq Shape] where
gatherRowsNat {rows cols k : Nat} :
Ref (.dim rows (.dim cols .scalar)) → Tensor Nat (.dim k .scalar) → m (Ref (.dim k (.dim cols
.scalar)))
/--
Read a float vector of integer-valued token ids and return a `Tensor Nat` index vector.

This is non-differentiable: gradients do not flow back into the float input. Language-model
benchmarks pass token ids as float inputs so each step can supply a fresh window without
re-instantiating the module.
-/
floatVecToNatTensor {k : Nat} : Ref (.dim k .scalar) → m (Tensor Nat (.dim k .scalar))
scatterAddVec {n : Nat} : Ref (.dim n .scalar) → Ref Shape.scalar → Fin n → m (Ref (.dim n
.scalar))
scatterAddRow {rows cols : Nat} :
Expand Down Expand Up @@ -457,6 +465,11 @@ def gatherRowsNat {rows cols k : Nat}
(x : Ref (m := m) (α := α) (.dim rows (.dim cols .scalar))) (idx : Tensor Nat (.dim k .scalar)) :
m (Ref (m := m) (α := α) (.dim k (.dim cols .scalar))) :=
Ops.gatherRowsNat (m := m) (α := α) (rows := rows) (cols := cols) (k := k) x idx
/-- Convert a float vector of integer token ids to `Tensor Nat` (non-differentiable). -/
def floatVecToNatTensor {k : Nat}
(x : Ref (m := m) (α := α) (.dim k .scalar)) :
m (Tensor Nat (.dim k .scalar)) :=
Ops.floatVecToNatTensor (m := m) (α := α) (k := k) x
/-- Re-export of `Ops.scatter_add_vec`. -/
def scatterAddVec {n : Nat}
(x : Ref (m := m) (α := α) (.dim n .scalar)) (v : Ref (m := m) (α := α) Shape.scalar) (i : Fin n)
Expand Down
18 changes: 18 additions & 0 deletions NN/Runtime/Autograd/Torch/Core/Ops.lean
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,24 @@ def gatherRowsNat {α : Type} (s : EagerSession α) [Add α] [Zero α] [Decidabl
pure (some { id := id })
dispatchCudaOpt (α := α) s "gather_rows_nat" cpu cuda

/--
Read a float input vector and return the corresponding `Tensor Nat` index vector.

Non-differentiable: used by token-id language-model losses that accept float-encoded ids as inputs.
-/
def floatVecToNatTensor {α : Type} (s : EagerSession α) [CudaBridge.TensorConv α] [DecidableEq Shape]
{k : Nat} (x : TensorRef α (.dim k .scalar)) : IO (Tensor Nat (.dim k .scalar)) := do
let v ← getValue (α := α) s (sh := .dim k .scalar) x
match v with
| .dim f =>
let ns ← (List.finRange k).mapM (fun i => do
match f i with
| .scalar fl => do
let ff ← CudaBridge.TensorConv.toFloat (α := α) fl
pure (UInt64.toNat (Float.floor ff).toUInt64))
pure <|
Tensor.dim (fun i => Tensor.scalar (ns.getD i.val 0))

/-- Gather `k` scalars using indices stored in the nat-environment (`NatVecRef`). -/
def gatherVecRef {α : Type} (s : EagerSession α) [Add α] [Zero α] [DecidableEq Shape]
{n k : Nat} (x : TensorRef α (.dim n .scalar)) (idx : NatVecRef k) :
Expand Down
4 changes: 4 additions & 0 deletions NN/Runtime/Autograd/Torch/Core/Trainer.lean
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ instance {α : Type} [Context α] [Internal.CudaBridge.TensorConv α] [Decidable
Internal.EagerSession.gatherVecNat (α := α) sess (n := n) (k := k) x idx
gatherRowsNat := fun {rows cols k} x idx => fun sess =>
Internal.EagerSession.gatherRowsNat (α := α) sess (rows := rows) (cols := cols) (k := k) x idx
floatVecToNatTensor := fun {k} x => fun sess =>
Internal.EagerSession.floatVecToNatTensor (α := α) sess (k := k) x
scatterAddVec := fun {n} x v i => fun sess =>
Internal.EagerSession.scatterAddVec (α := α) sess (n := n) x v i
scatterAddRow := fun {rows cols} x v i => fun sess =>
Expand Down Expand Up @@ -245,6 +247,8 @@ instance {α : Type} [Context α] [DecidableEq Shape] {Γ : List Shape} :
gatherRowsNat := fun {rows cols k} x idx =>
Runtime.Autograd.Compiled.GraphM.gatherRowsNat (α := α) (Γ := Γ) (rows := rows) (cols := cols)
(k := k) x idx
floatVecToNatTensor := fun {_k} _x =>
throw "compiled GraphM: floatVecToNatTensor requires eager backend (dynamic token ids)"
scatterAddVec := fun {n} x v i =>
Runtime.Autograd.Compiled.GraphM.scatterAddVec (α := α) (Γ := Γ) (n := n) x v i
scatterAddRow := fun {rows cols} x v i =>
Expand Down
6 changes: 6 additions & 0 deletions NN/Runtime/Autograd/Torch/Core/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ structure Options where
-/
fastGpuMatmulPrecision : Runtime.Autograd.FastKernels.GpuMatmulPrecision := .fp32
/--
Skip autograd during backward (inference-style evaluation).

Used by `eval1NoGrad` and similar helpers that run forward without recording gradients.
-/
noGrad : Bool := false
/--
Eager execution on CUDA.

When `true` and `backend = .eager`, the eager session uses the CUDA tape
Expand Down
7 changes: 7 additions & 0 deletions NN/Runtime/Autograd/TorchLean/Functional/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,13 @@ def embeddingBatchSeqNat {α : Type} [Context α] [DecidableEq Shape]
gathered (by
simp [Shape.size, Nat.mul_assoc])

/-- Read float-encoded token ids as a `Tensor Nat` index vector (non-differentiable). -/
def floatVecToNatTensor {α : Type} [Context α] [DecidableEq Shape]
{m : TypeType} [Monad m] [Ops (m := m) (α := α)] {k : Nat}
(x : RefTy (m := m) (α := α) (.dim k .scalar)) :
m (Tensor Nat (.dim k .scalar)) :=
_root_.Runtime.Autograd.Torch.floatVecToNatTensor (m := m) (α := α) (k := k) x

/-! ## Reductions -/

/--
Expand Down
1 change: 1 addition & 0 deletions NN/Runtime/Autograd/TorchLean/NN/Seq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def programWithMode {σ τ : Shape} (mode : Mode) (model : Seq σ τ)
[_root_.Runtime.Autograd.Torch.Internal.CudaBridge.TensorConv α]
(params : _root_.Runtime.Autograd.Torch.ParamList α (paramShapes model))
(x : Spec.Tensor α σ) : IO (Spec.Tensor α τ) := do
let opts := { opts with noGrad := true }
let sess ← _root_.Runtime.Autograd.Torch.Internal.EagerSession.new (α := α) opts
sess.resetTape
let outRef ← (do
Expand Down
2 changes: 2 additions & 0 deletions NN/Verification/TorchLean/Compile.lean
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ instance {α : Type} [Context α] [DecidableEq Shape] :
fail (α := α) "TorchLean→IR: gather is outside the verifier IR fragment"
gatherRowsNat := fun {_rows _cols _k} _x _idx =>
fail (α := α) "TorchLean→IR: gather is outside the verifier IR fragment"
floatVecToNatTensor := fun {_k} _x =>
fail (α := α) "TorchLean→IR: float_vec_to_nat_tensor is outside the verifier IR fragment"

scatterAddVec := fun {_n} _x _val _i =>
fail (α := α) "TorchLean→IR: scatter is outside the verifier IR fragment"
Expand Down
2 changes: 2 additions & 0 deletions NN/Verification/TorchLean/SpecEval.lean
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ instance {α : Type} [Context α] [DecidableEq Shape] : Runtime.Autograd.Torch.O
"TorchLeanSpecEval: gather_vec_nat not supported in spec backend"
gatherRowsNat := fun {_rows _cols _k} _x _idx => throw
"TorchLeanSpecEval: gather_rows_nat not supported in spec backend"
floatVecToNatTensor := fun {_k} _x => throw
"TorchLeanSpecEval: float_vec_to_nat_tensor not supported in spec backend"
scatterAddVec := fun {_n} _x _val _i => throw
"TorchLeanSpecEval: scatter_add_vec not supported in spec backend"
scatterAddRow := fun {_rows _cols} _x _row _i => throw
Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ lake exe torchlean mlp --steps 10 --dtype float --backend eager

# Optional CUDA run, if the CUDA toolkit and an NVIDIA GPU are available:
lake build -K cuda=true
lake exe -K cuda=true torchlean mlp --cuda --fast-kernels --steps 1000
lake exe -K cuda=true torchlean mlp --cuda --steps 1000
```

The first MLP command uses the executable IEEE-style Float32 path. The second
uses Lean's builtin `Float` runtime path. The CUDA command uses the native GPU
runtime path and checks that the CUDA backend is available; it is not a trusted
proof boundary.
proof boundary. Handwritten fused CUDA kernels are opt-in performance paths via
`--fast-kernels`; keep them off unless you are explicitly testing that backend.

TorchLean is pinned by `lean-toolchain` and currently builds with
`leanprover/lean4:v4.31.0`.
Expand Down
8 changes: 8 additions & 0 deletions csrc/cuda/common/torchlean_cuda_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ static inline void checkCuda(cudaError_t e, const char* msg) {
}
}

static inline void torchlean_cuda_clear_pending_error() {
(void)cudaGetLastError();
}

static inline void torchlean_cuda_check_launch(const char* msg) {
checkCuda(cudaGetLastError(), msg);
}

static inline void torchlean_cuda_free_checked(void** ptr, const char* msg) {
if (ptr && *ptr) {
checkCuda(cudaFree(*ptr), msg);
Expand Down
9 changes: 6 additions & 3 deletions csrc/cuda/tensor/torchlean_cuda_tensor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -891,8 +891,9 @@ extern "C" LEAN_EXPORT lean_obj_res torchlean_cuda_buffer_to_float_array(b_lean_
} \
dim3 blocks = torchlean_blocks_for(b->size); \
dim3 threads = dim3(kBlockSize); \
torchlean_cuda_clear_pending_error(); \
KERNEL<<<blocks, threads>>>(b->data, out->data, b->size, (float)c); \
checkCuda(cudaGetLastError(), "cuda " LABEL " kernel launch failed"); \
torchlean_cuda_check_launch("cuda " LABEL " kernel launch failed"); \
return torchlean_cuda_buffer_box(out); \
}

Expand Down Expand Up @@ -1114,13 +1115,15 @@ extern "C" LEAN_EXPORT lean_obj_res torchlean_cuda_buffer_reduce_mean(b_lean_obj
"cudaMemcpy reduceMean init failed");
dim3 blocks = torchlean_blocks_for(b->size);
dim3 threads = dim3(kBlockSize);
torchlean_cuda_clear_pending_error();
torchlean_reduce_sum_f32<<<blocks, threads>>>(b->data, out->data, b->size);
checkCuda(cudaGetLastError(), "cuda reduceMean reduce kernel launch failed");
torchlean_cuda_check_launch("cuda reduceMean reduce kernel launch failed");
}

float scale = 1.0f / (float)b->size;
torchlean_cuda_clear_pending_error();
torchlean_scale1_f32<<<dim3(1), dim3(1)>>>(out->data, scale);
checkCuda(cudaGetLastError(), "cuda reduceMean scale kernel launch failed");
torchlean_cuda_check_launch("cuda reduceMean scale kernel launch failed");

return torchlean_cuda_buffer_box(out);
}
Loading