commit 2c30e06f206299aee3b1603393571d6e4ddd61de Author: Soma Nakamura
Date: Thu Jul 10 18:09:14 2025 +0900
initial
diff --git a/.devenv.flake.nix b/.devenv.flake.nix
new file mode 100644
index 0000000..b7b7adc
--- /dev/null
+++ b/.devenv.flake.nix
@@ -0,0 +1,163 @@
+{
+  inputs =
+    let
+      version = "1.6.1";
+      system = "x86_64-linux";
+      devenv_root = "/home/centra/dev/pnn/progressive-llm-training";
+      devenv_dotfile = ./.devenv;
+      devenv_dotfile_string = ".devenv";
+      container_name = null;
+      devenv_tmpdir = "/run/user/1000";
+      devenv_runtime = "/run/user/1000/devenv-adeda32";
+      devenv_istesting = false;
+      devenv_direnvrc_latest_version = 1;
+
+    in {
+ git-hooks.url = "github:cachix/git-hooks.nix";
+ git-hooks.inputs.nixpkgs.follows = "nixpkgs";
+ pre-commit-hooks.follows = "git-hooks";
+ nixpkgs.url = "github:cachix/devenv-nixpkgs/rolling";
+ devenv.url = "github:cachix/devenv?dir=src/modules";
+ } // (if builtins.pathExists (devenv_dotfile + "/flake.json")
+ then builtins.fromJSON (builtins.readFile (devenv_dotfile + "/flake.json"))
+ else { });
+
+ outputs = { nixpkgs, ... }@inputs:
+ let
+      version = "1.6.1";
+      system = "x86_64-linux";
+      devenv_root = "/home/centra/dev/pnn/progressive-llm-training";
+      devenv_dotfile = ./.devenv;
+      devenv_dotfile_string = ".devenv";
+      container_name = null;
+      devenv_tmpdir = "/run/user/1000";
+      devenv_runtime = "/run/user/1000/devenv-adeda32";
+      devenv_istesting = false;
+      devenv_direnvrc_latest_version = 1;
+
+ devenv =
+ if builtins.pathExists (devenv_dotfile + "/devenv.json")
+ then builtins.fromJSON (builtins.readFile (devenv_dotfile + "/devenv.json"))
+ else { };
+ getOverlays = inputName: inputAttrs:
+ map
+ (overlay:
+ let
+ input = inputs.${inputName} or (throw "No such input `${inputName}` while trying to configure overlays.");
+ in
+ input.overlays.${overlay} or (throw "Input `${inputName}` has no overlay called `${overlay}`. Supported overlays: ${nixpkgs.lib.concatStringsSep ", " (builtins.attrNames input.overlays)}"))
+ inputAttrs.overlays or [ ];
+ overlays = nixpkgs.lib.flatten (nixpkgs.lib.mapAttrsToList getOverlays (devenv.inputs or { }));
+ pkgs = import nixpkgs {
+ inherit system;
+ config = {
+ allowUnfree = devenv.allowUnfree or false;
+ allowBroken = devenv.allowBroken or false;
+ permittedInsecurePackages = devenv.permittedInsecurePackages or [ ];
+ };
+ inherit overlays;
+ };
+ lib = pkgs.lib;
+ importModule = path:
+ if lib.hasPrefix "./" path
+ then if lib.hasSuffix ".nix" path
+ then ./. + (builtins.substring 1 255 path)
+ else ./. + (builtins.substring 1 255 path) + "/devenv.nix"
+ else if lib.hasPrefix "../" path
+ then throw "devenv: ../ is not supported for imports"
+ else
+ let
+ paths = lib.splitString "/" path;
+ name = builtins.head paths;
+ input = inputs.${name} or (throw "Unknown input ${name}");
+ subpath = "/${lib.concatStringsSep "/" (builtins.tail paths)}";
+ devenvpath = "${input}" + subpath;
+ devenvdefaultpath = devenvpath + "/devenv.nix";
+ in
+ if lib.hasSuffix ".nix" devenvpath
+ then devenvpath
+ else if builtins.pathExists devenvdefaultpath
+ then devenvdefaultpath
+ else throw (devenvdefaultpath + " file does not exist for input ${name}.");
+ project = pkgs.lib.evalModules {
+ specialArgs = inputs // { inherit inputs; };
+ modules = [
+ ({ config, ... }: {
+ _module.args.pkgs = pkgs.appendOverlays (config.overlays or [ ]);
+ })
+ (inputs.devenv.modules + /top-level.nix)
+ {
+ devenv.cliVersion = version;
+ devenv.root = devenv_root;
+ devenv.dotfile = devenv_root + "/" + devenv_dotfile_string;
+ }
+ (pkgs.lib.optionalAttrs (inputs.devenv.isTmpDir or false) {
+ devenv.tmpdir = devenv_tmpdir;
+ devenv.runtime = devenv_runtime;
+ })
+ (pkgs.lib.optionalAttrs (inputs.devenv.hasIsTesting or false) {
+ devenv.isTesting = devenv_istesting;
+ })
+ (pkgs.lib.optionalAttrs (container_name != null) {
+ container.isBuilding = pkgs.lib.mkForce true;
+ containers.${container_name}.isBuilding = true;
+ })
+ ({ options, ... }: {
+ config.devenv = pkgs.lib.optionalAttrs (builtins.hasAttr "direnvrcLatestVersion" options.devenv) {
+ direnvrcLatestVersion = devenv_direnvrc_latest_version;
+ };
+ })
+ ] ++ (map importModule (devenv.imports or [ ])) ++ [
+ (if builtins.pathExists ./devenv.nix then ./devenv.nix else { })
+ (devenv.devenv or { })
+ (if builtins.pathExists ./devenv.local.nix then ./devenv.local.nix else { })
+ (if builtins.pathExists (devenv_dotfile + "/cli-options.nix") then import (devenv_dotfile + "/cli-options.nix") else { })
+ ];
+ };
+ config = project.config;
+
+ options = pkgs.nixosOptionsDoc {
+ options = builtins.removeAttrs project.options [ "_module" ];
+ warningsAreErrors = false;
+ # Unpack Nix types, e.g. literalExpression, mDoc.
+ transformOptions =
+ let isDocType = v: builtins.elem v [ "literalDocBook" "literalExpression" "literalMD" "mdDoc" ];
+ in lib.attrsets.mapAttrs (_: v:
+ if v ? _type && isDocType v._type then
+ v.text
+ else if v ? _type && v._type == "derivation" then
+ v.name
+ else
+ v
+ );
+ };
+
+ build = options: config:
+ lib.concatMapAttrs
+ (name: option:
+ if builtins.hasAttr "type" option then
+ if option.type.name == "output" || option.type.name == "outputOf" then {
+ ${name} = config.${name};
+ } else { }
+ else
+ let v = build option config.${name};
+ in if v != { } then {
+ ${name} = v;
+ } else { }
+ )
+ options;
+
+ systems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ];
+ in
+ {
+ devShell = lib.genAttrs systems (system: config.shell);
+ packages = lib.genAttrs systems (system: {
+ optionsJSON = options.optionsJSON;
+ # deprecated
+ inherit (config) info procfileScript procfileEnv procfile;
+ ci = config.ciDerivation;
+ });
+ devenv = config;
+ build = build project.options project.config;
+ };
+ }
diff --git a/.devenv/bash b/.devenv/bash
new file mode 120000
index 0000000..3eab571
--- /dev/null
+++ b/.devenv/bash
@@ -0,0 +1 @@
+/nix/store/94lg0shvsfc845zy8gnflvpqxxiyijbz-bash-interactive-5.2p37
\ No newline at end of file
diff --git a/.devenv/devenv.json b/.devenv/devenv.json
new file mode 100644
index 0000000..bfa79af
--- /dev/null
+++ b/.devenv/devenv.json
@@ -0,0 +1 @@
+{"inputs":{"nixpkgs":{"url":"github:NixOS/nixpkgs/nixos-unstable"},"nixpkgs-python":{"url":"github:cachix/nixpkgs-python","inputs":{"nixpkgs":{"follows":"nixpkgs"}}}},"allowUnfree":true}
\ No newline at end of file
diff --git a/.devenv/flake.json b/.devenv/flake.json
new file mode 100644
index 0000000..c487dcb
--- /dev/null
+++ b/.devenv/flake.json
@@ -0,0 +1 @@
+{"nixpkgs":{"url":"github:NixOS/nixpkgs/nixos-unstable"},"nixpkgs-python":{"url":"github:cachix/nixpkgs-python","inputs":{"nixpkgs":{"follows":"nixpkgs"}}}}
\ No newline at end of file
diff --git a/.devenv/gc/shell b/.devenv/gc/shell
new file mode 120000
index 0000000..2b5306e
--- /dev/null
+++ b/.devenv/gc/shell
@@ -0,0 +1 @@
+shell-1-link
\ No newline at end of file
diff --git a/.devenv/gc/shell-1-link b/.devenv/gc/shell-1-link
new file mode 120000
index 0000000..eacdc2d
--- /dev/null
+++ b/.devenv/gc/shell-1-link
@@ -0,0 +1 @@
+/nix/store/7fimdw1in7f1g0wxw5cr9pg26rs4rp5g-devenv-shell-env
\ No newline at end of file
diff --git a/.devenv/imports.txt b/.devenv/imports.txt
new file mode 100644
index 0000000..e69de29
diff --git a/.devenv/input-paths.txt b/.devenv/input-paths.txt
new file mode 100644
index 0000000..6d1c4e8
--- /dev/null
+++ b/.devenv/input-paths.txt
@@ -0,0 +1,11 @@
+/home/centra/.config/nixpkgs/config.nix
+/home/centra/.config/nixpkgs/overlays
+/home/centra/.config/nixpkgs/overlays.nix
+/home/centra/.nixpkgs/config.nix
+/home/centra/dev/pnn/progressive-llm-training/.devenv/flake.json
+/home/centra/dev/pnn/progressive-llm-training/.devenv.flake.nix
+/home/centra/dev/pnn/progressive-llm-training/.env
+/home/centra/dev/pnn/progressive-llm-training/devenv.local.nix
+/home/centra/dev/pnn/progressive-llm-training/devenv.lock
+/home/centra/dev/pnn/progressive-llm-training/devenv.nix
+/home/centra/dev/pnn/progressive-llm-training/devenv.yaml
\ No newline at end of file
diff --git a/.devenv/load-exports b/.devenv/load-exports
new file mode 100755
index 0000000..c0b1498
--- /dev/null
+++ b/.devenv/load-exports
@@ -0,0 +1,3 @@
+export PATH='/home/centra/dev/pnn/progressive-llm-training/.devenv/state/venv/bin:/nix/store/bdqwd2frn9m7n3hj2436s0vlnv7mawpc-python3-3.11.13-env/bin:/nix/store/9w80x8njl1hcp8vlk1f3x17q4hcd2cqp-evaluate/bin:/nix/store/8df6wqahd2fqzl04kcs3xs32yqqimcxb-install-packages/bin:/nix/store/v5rz1h6ci23icfp6y228r2m0fqrdf408-install-packages-cpu/bin:/nix/store/69142b4sjmb4jffmyjb8nar6qzlgxnpg-prepare-data/bin:/nix/store/bhb6l6yfqknnwc7y5j5xc9k866hajv7b-train/bin:/nix/store/pbqah1qk4b5y14fqinr1h8zvhqy71v81-gcc-wrapper-14.3.0/bin:/nix/store/sa7j7cddyblhcb3ch3ds10w7nw75yjj1-gcc-14.3.0/bin:/nix/store/mdmsnfcvxyk5ynz7nx8nhss1wig0gljx-glibc-2.40-66-bin/bin:/nix/store/psy9v2asypgl9ylg8cnzkixc7fv0snj0-coreutils-9.7/bin:/nix/store/cadx5p7c0i06gf6h84iw9mrhx56imbv0-binutils-wrapper-2.44/bin:/nix/store/z3za8hfc24wb117s50p8b10agjkgm039-binutils-2.44/bin:/nix/store/dx4bdrs7mq3jfviqhszrc7l35ps9kg64-cmake-3.31.7/bin:/nix/store/1492q00cm64n0hs5966s8cqj6j0x5nxg-ninja-1.12.1/bin:/nix/store/h5khrpnjj3fb182sc32fx1z75w0lhksy-pkg-config-wrapper-0.29.2/bin:/nix/store/rzqvhv48m3nh8g3j4k6jmz6yqy8apr95-git-2.49.0/bin:/nix/store/nygfbkv0j6fvwwa82mdwxm4qfiq3p4q2-git-lfs-3.6.1/bin:/nix/store/fir4g1m8dvg46mh8silh3wnmm9mc0jix-htop-3.4.1/bin:/nix/store/9mc2m4sacbk4l7sc4w7m08m1x9bf5dgn-tmux-3.5a/bin:/nix/store/cxy72qdk41k3zjs5fw1nw1whv6wf7hv2-vim-9.1.1401/bin:/nix/store/74k8qwbfa6lm8psm2vjh2vj04fpr6c5g-openssl-3.4.1-bin/bin:/nix/store/m9k83ip1yx29xs94sa5x8j70s2vfgj6i-glib-2.84.2-dev/bin:/nix/store/zs5crhr67zp8cxn7dh4mwq08zw3sb31m-gettext-0.22.5/bin:/nix/store/rklrz4rwi03hxvz0kwh75vz55wb9b1qz-glib-2.84.2-bin/bin:/nix/store/xbpwk3xzanxj12157byj6wjagm2wfb3c-cuda-merged-12.8/bin:/nix/store/v0zrnzl3anb71ma5c2kx71dl8kyh0rf6-cuda_cuobjdump-12.8.90-bin/bin:/nix/store/v4mm21f67qki6ss6mqp3anlmaiw0r1zd-pre-commit-bin/bin:/nix/store/mq2i9br9h890bnahlds9jnff1jf6xjpb-python3.13-black-25.1.0/bin:/nix/store/sd81bvmch7njdpwx3lkjslixcbj5mivz-python3-3.13.4/bin:/nix/store/mdzm1l0rnpwp8ha0mbxll0db4r2p0xj3-python3.13-flake8-7.2.0/bin:/nix/store/xs72vlx7i6snrrrqx2zn529fbbqrwlwq-python3.13-pycodestyle-2.13.0/bin:/nix/store/5a8m3p0svp6myq1cz4ww431fsbh3xrg5-python3.13-pyflakes-3.3.2/bin:/nix/store/p6bch581drrxv3dm7vwxqazpbssjz4nv-python3.13-mypy-1.15.0/bin:/nix/store/1c8sm86wj45vwkb3ww2b870h9i9wna6r-patchelf-0.15.0/bin:/nix/store/psy9v2asypgl9ylg8cnzkixc7fv0snj0-coreutils-9.7/bin:/nix/store/c14zwgl8hf1wm0izij2i16xvk8ak70cy-findutils-4.10.0/bin:/nix/store/ibx4jfwlhjg4g0s6rrxrpaxa3ka8ns4m-diffutils-3.12/bin:/nix/store/pr318zsl44jdwpk9wk0sdrn19b6in7ah-gnused-4.9/bin:/nix/store/bc6zxzjnkjp4r9nhz5imy3cypvdh6r4n-gnugrep-3.12/bin:/nix/store/nv3y7zb1cwz1h9qy7nwz0s54j8dl1kqj-gawk-5.3.2/bin:/nix/store/lp82dcnrzljyix6yigwzrlpr1smvpmb0-gnutar-1.35/bin:/nix/store/6ag5dhk7sma61p6vl0kazfmpbrq08nqh-gzip-1.14/bin:/nix/store/ykdv4id6893gmkqwdmbimq237c1xqvq7-bzip2-1.0.8-bin/bin:/nix/store/6bwp1y45zlyvpr4ja2sk1yi9v5mrs94x-gnumake-4.4.1/bin:/nix/store/00zrahbb32nzawrmv9sjxn36h7qk9vrs-bash-5.2p37/bin:/nix/store/c9xmgszbf6i4dfq9r953khk9d7fdqigw-patch-2.8/bin:/nix/store/ikfwx7kbwz9zr7fziiac7f57jgbh3bnv-xz-5.8.1-bin/bin:/nix/store/3pdmbqy86wsbjdazxv1n3vrmj60vn0ri-file-5.45/bin:/run/wrappers/bin:/home/centra/.local/share/flatpak/exports/bin:/var/lib/flatpak/exports/bin:/home/centra/.nix-profile/bin:/nix/profile/bin:/home/centra/.local/state/nix/profile/bin:/etc/profiles/per-user/centra/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin'
+export VIRTUAL_ENV=/home/centra/dev/pnn/progressive-llm-training/.devenv/state/venv
+
diff --git a/.devenv/nix-eval-cache.db b/.devenv/nix-eval-cache.db
new file mode 100644
index 0000000..7ee7c11
Binary files /dev/null and b/.devenv/nix-eval-cache.db differ
diff --git a/.devenv/nix-eval-cache.db-shm b/.devenv/nix-eval-cache.db-shm
new file mode 100644
index 0000000..206cf3b
Binary files /dev/null and b/.devenv/nix-eval-cache.db-shm differ
diff --git a/.devenv/nix-eval-cache.db-wal b/.devenv/nix-eval-cache.db-wal
new file mode 100644
index 0000000..aea3837
Binary files /dev/null and b/.devenv/nix-eval-cache.db-wal differ
diff --git a/.devenv/profile b/.devenv/profile
new file mode 120000
index 0000000..0a7733c
--- /dev/null
+++ b/.devenv/profile
@@ -0,0 +1 @@
+/nix/store/y2vscmx3lckyzyag6xg8b02pkdsk326d-devenv-profile
\ No newline at end of file
diff --git a/.devenv/run b/.devenv/run
new file mode 120000
index 0000000..5d29d76
--- /dev/null
+++ b/.devenv/run
@@ -0,0 +1 @@
+/run/user/1000/devenv-adeda32
\ No newline at end of file
diff --git a/.devenv/state/git-hooks/config.json b/.devenv/state/git-hooks/config.json
new file mode 100644
index 0000000..be68384
--- /dev/null
+++ b/.devenv/state/git-hooks/config.json
@@ -0,0 +1 @@
+{"configPath": ".pre-commit-config.yaml"}
diff --git a/.devenv/tasks.db b/.devenv/tasks.db
new file mode 100644
index 0000000..7ee7c11
Binary files /dev/null and b/.devenv/tasks.db differ
diff --git a/.devenv/tasks.db-shm b/.devenv/tasks.db-shm
new file mode 100644
index 0000000..f5895a2
Binary files /dev/null and b/.devenv/tasks.db-shm differ
diff --git a/.devenv/tasks.db-wal b/.devenv/tasks.db-wal
new file mode 100644
index 0000000..8411545
Binary files /dev/null and b/.devenv/tasks.db-wal differ
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..20b8504
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,32 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+venv/
+ENV/
+env/
+.venv/
+
+# Nix
+result
+result-*
+
+# Project specific
+outputs/
+data/
+*.log
+wandb/
+.ipynb_checkpoints/
+*.pt
+*.pth
+*.bin
+*.safetensors
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
\ No newline at end of file
diff --git a/=2.5.0 b/=2.5.0
new file mode 100644
index 0000000..0d0eafa
--- /dev/null
+++ b/=2.5.0
@@ -0,0 +1,33 @@
+Collecting flash-attn
+ Using cached flash_attn-2.8.0.post2-cp311-cp311-linux_x86_64.whl
+Requirement already satisfied: torch in ./.devenv/state/venv/lib/python3.11/site-packages (from flash-attn) (2.7.1+cu128)
+Collecting einops (from flash-attn)
+ Using cached einops-0.8.1-py3-none-any.whl.metadata (13 kB)
+Requirement already satisfied: filelock in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (3.13.1)
+Requirement already satisfied: typing-extensions>=4.10.0 in /nix/store/x74hdbjsz4ck98w8lyxv8kkwxs1wm2il-python3.13-typing-extensions-4.13.2/lib/python3.13/site-packages (from torch->flash-attn) (4.13.2)
+Requirement already satisfied: sympy>=1.13.3 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (1.13.3)
+Requirement already satisfied: networkx in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (3.3)
+Requirement already satisfied: jinja2 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (3.1.4)
+Requirement already satisfied: fsspec in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (2024.6.1)
+Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.61 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.61)
+Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.57 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.57)
+Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.57 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.57)
+Requirement already satisfied: nvidia-cudnn-cu12==9.7.1.26 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (9.7.1.26)
+Requirement already satisfied: nvidia-cublas-cu12==12.8.3.14 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.3.14)
+Requirement already satisfied: nvidia-cufft-cu12==11.3.3.41 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (11.3.3.41)
+Requirement already satisfied: nvidia-curand-cu12==10.3.9.55 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (10.3.9.55)
+Requirement already satisfied: nvidia-cusolver-cu12==11.7.2.55 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (11.7.2.55)
+Requirement already satisfied: nvidia-cusparse-cu12==12.5.7.53 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.5.7.53)
+Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (0.6.3)
+Requirement already satisfied: nvidia-nccl-cu12==2.26.2 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (2.26.2)
+Requirement already satisfied: nvidia-nvtx-cu12==12.8.55 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.55)
+Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.61 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.61)
+Requirement already satisfied: nvidia-cufile-cu12==1.13.0.11 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (1.13.0.11)
+Requirement already satisfied: triton==3.3.1 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (3.3.1)
+Requirement already satisfied: setuptools>=40.8.0 in ./.devenv/state/venv/lib/python3.11/site-packages (from triton==3.3.1->torch->flash-attn) (80.9.0)
+Requirement already satisfied: mpmath<1.4,>=1.1.0 in ./.devenv/state/venv/lib/python3.11/site-packages (from sympy>=1.13.3->torch->flash-attn) (1.3.0)
+Requirement already satisfied: MarkupSafe>=2.0 in ./.devenv/state/venv/lib/python3.11/site-packages (from jinja2->torch->flash-attn) (2.1.5)
+Using cached einops-0.8.1-py3-none-any.whl (64 kB)
+Installing collected packages: einops, flash-attn
+
+Successfully installed einops-0.8.1 flash-attn-2.8.0.post2
diff --git a/LORA_TARGET_MODULES.md b/LORA_TARGET_MODULES.md
new file mode 100644
index 0000000..39aa5c8
--- /dev/null
+++ b/LORA_TARGET_MODULES.md
@@ -0,0 +1,142 @@
+# LoRA Target Modules Reference
+
+This document provides the correct target module names for different model architectures when using LoRA (Low-Rank Adaptation).
+
+## Model Architecture Detection
+
+Use the inspection script to find correct target modules:
+
+```bash
+# In the nix development environment
+python /home/centra/dev/pnn/inspect_conv1d_model.py [model_name]
+```
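+
+If the inspection script is unavailable, a minimal sketch along these
+lines lists candidate modules directly (assumes `transformers` is
+installed; the model name is illustrative):
+
+```python
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
+for name, module in model.named_modules():
+    # LoRA targets are leaf layers: Conv1D (GPT-2 family) or Linear
+    if module.__class__.__name__ in ("Conv1D", "Linear"):
+        print(name, "->", module.__class__.__name__)
+```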
+
+## Common Model Architectures
+
+### GPT-2 / DialoGPT Models
+- **Model Type**: GPT2LMHeadModel
+- **Layer Type**: Conv1D (not Linear!)
+- **Base Model**: microsoft/DialoGPT-small, gpt2, gpt2-medium, gpt2-large, gpt2-xl
+
+#### Attention Modules
+- `c_attn` - Combined query, key, value projection (nf=3*hidden_size)
+- `c_proj` - Output projection
+
+#### MLP Modules
+- `mlp.c_fc` - Feed-forward up projection
+- `mlp.c_proj` - Feed-forward down projection
+
+#### Recommended Configurations
+```yaml
+# Basic stage (attention only)
+target_modules: ["c_attn", "c_proj"]
+
+# Advanced stage (attention + MLP)
+target_modules: ["c_attn", "c_proj", "mlp.c_fc", "mlp.c_proj"]
+```
+
+### LLaMA Models
+- **Model Type**: LlamaForCausalLM
+- **Layer Type**: Linear
+- **Base Model**: meta-llama/Llama-2-7b-hf, meta-llama/Llama-3.2-8B
+
+#### Attention Modules
+- `q_proj` - Query projection
+- `k_proj` - Key projection
+- `v_proj` - Value projection
+- `o_proj` - Output projection
+
+#### MLP Modules
+- `gate_proj` - Gate projection
+- `up_proj` - Up projection
+- `down_proj` - Down projection
+
+#### Recommended Configurations
+```yaml
+# Basic stage (attention only)
+target_modules: ["q_proj", "v_proj"]
+
+# Advanced stage (attention + MLP)
+target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+```
+
+### Mistral Models
+- **Model Type**: MistralForCausalLM
+- **Layer Type**: Linear
+- **Base Model**: mistralai/Mistral-7B-v0.1
+
+#### Target Modules (same as LLaMA)
+```yaml
+target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+```
+
+### Qwen Models
+- **Model Type**: QWenLMHeadModel
+- **Layer Type**: Linear
+- **Base Model**: Qwen/Qwen-7B
+
+#### Target Modules
+```yaml
+target_modules: ["c_attn", "c_proj", "w1", "w2"]
+```
+
+## Important Notes
+
+1. **Conv1D vs Linear**: GPT-2 based models use `Conv1D` layers, not `Linear` layers
+2. **Module Patterns**: Use simple patterns like `"c_attn"` rather than full paths like `"transformer.h.0.attn.c_attn"`
+3. **Testing**: Always test your configuration before training by creating a PEFT model
+4. **Architecture Variations**: Different model families use different naming conventions
+
+## Troubleshooting
+
+### Error: "Target module not found"
+- Run the inspection script to find actual module names
+- Check if the model uses Conv1D or Linear layers
+- Verify the module naming pattern for your specific model
+
+### Error: "No trainable parameters"
+- Ensure target modules exist in the model
+- Check that the module names match exactly
+- Verify the model architecture is supported by PEFT
+
+## Testing Your Configuration
+
+```python
+from transformers import AutoModelForCausalLM
+from peft import get_peft_model, LoraConfig, TaskType
+
+# Load the base model to test against (DialoGPT-small uses c_attn/c_proj)
+model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
+
+# Test configuration
+lora_config = LoraConfig(
+    task_type=TaskType.CAUSAL_LM,
+    r=8,
+    lora_alpha=16,
+    lora_dropout=0.1,
+    target_modules=["c_attn", "c_proj"],  # Your target modules
+    bias="none"
+)
+
+# Try to create the PEFT model
+try:
+    peft_model = get_peft_model(model, lora_config)
+    peft_model.print_trainable_parameters()
+    print("✓ Configuration works!")
+except Exception as e:
+    print(f"✗ Configuration failed: {e}")
+```
\ No newline at end of file
diff --git a/config/README.md b/config/README.md
new file mode 100644
index 0000000..797a4a0
--- /dev/null
+++ b/config/README.md
@@ -0,0 +1,134 @@
+# Training Configuration Files
+
+This directory contains configuration files for different model sizes and use cases.
+
+## Available Configurations
+
+### Small Models (Testing)
+- `training_config.yaml` - Default configuration for small models (DialoGPT-small)
+ - Memory: ~1GB VRAM
+ - Batch size: 8
+ - No quantization
+
+### Medium Models (8B)
+- `training_config_large.yaml` - Configuration for 8B models (Llama-3.2-8B)
+ - Memory: ~12GB VRAM with 4-bit quantization
+ - Batch size: 1, gradient accumulation: 16-64
+ - 4-bit quantization enabled
+
+### Large Models (13B)
+- `training_config_13b.yaml` - Configuration for 13B models
+ - Memory: ~16GB VRAM with 4-bit quantization
+ - Batch size: 1, gradient accumulation: 32-128
+ - Higher LoRA ranks (32-128)
+
+### Extra Large Models (70B)
+- `training_config_70b.yaml` - Configuration for 70B models
+ - Memory: ~40GB+ VRAM with 4-bit quantization
+ - Batch size: 1, gradient accumulation: 64-256
+ - Maximum LoRA ranks (64-256)
+ - Multi-GPU support with FSDP
+
+## Configuration Parameters
+
+### Model Settings
+- `load_in_4bit`: Enable 4-bit quantization (recommended for large models)
+- `gradient_checkpointing`: Trade compute for memory
+- `use_flash_attention_2`: Faster attention computation if available
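+
+A minimal sketch of these knobs with `transformers` and `bitsandbytes`
+(model name illustrative; assumes a CUDA GPU):
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+bnb = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.bfloat16,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
+)
+model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.2-8B",
+    quantization_config=bnb,
+    device_map="auto",
+)
+model.gradient_checkpointing_enable()  # trade compute for memory
+```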
+
+### Adapter Settings
+- `r`: LoRA rank (higher = more parameters but better capacity)
+- `lora_alpha`: LoRA scaling factor (typically 2x the rank)
+- `init_lora_weights`: Set to `true` for identity initialization
+
+### Training Settings
+- `per_device_batch_size`: Usually 1 for large models
+- `gradient_accumulation_steps`: Effective batch size multiplier
+- `learning_rate`: Lower for larger models
+- `bf16`: Use bfloat16 for better numerical stability
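+
+For example, the effective batch size is the per-device batch size times
+the gradient accumulation steps (times the number of GPUs); a quick check
+against a config file (a sketch, assumes PyYAML):
+
+```python
+import yaml
+
+with open("config/training_config_large.yaml") as f:
+    cfg = yaml.safe_load(f)
+
+for stage in cfg["progressive_stages"]:
+    t = stage["training"]
+    eff = t["per_device_batch_size"] * t["gradient_accumulation_steps"]
+    print(f"{stage['name']}: effective batch size = {eff}")
+```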
+
+## Usage
+
+```bash
+# For 8B models
+python scripts/train_progressive.py --config config/training_config_large.yaml
+
+# For 13B models
+python scripts/train_progressive.py --config config/training_config_13b.yaml
+
+# For 70B models (requires multiple GPUs)
+python scripts/train_progressive.py --config config/training_config_70b.yaml
+```
+
+## Memory Requirements
+
+| Model Size | VRAM (4-bit) | VRAM (16-bit) | GPUs Recommended |
+|------------|--------------|---------------|------------------|
+| 8B | 12-16GB | 32GB | 1x RTX 4090 |
+| 13B | 16-20GB | 52GB | 1x A100 |
+| 70B | 40-60GB | 140GB | 2x A100 |
+
+## Tips for Large Models
+
+1. **Start with smaller models** to validate your approach
+2. **Use gradient checkpointing** to reduce memory usage
+3. **Monitor GPU memory** during training (see the sketch after this list)
+4. **Use lower learning rates** for stability
+5. **Consider multi-GPU setup** for 70B+ models
+6. **Enable flash attention** if available for speed
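+
+A quick way to monitor GPU memory from inside the training loop (a
+sketch; assumes PyTorch with CUDA available):
+
+```python
+import torch
+
+if torch.cuda.is_available():
+    used = torch.cuda.memory_allocated() / 1024**3
+    peak = torch.cuda.max_memory_allocated() / 1024**3
+    print(f"GPU memory: {used:.1f} GiB in use, {peak:.1f} GiB peak")
+```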
+
+## Troubleshooting
+
+- **OOM errors**: Reduce batch size or enable gradient checkpointing
+- **Slow training**: Enable flash attention, use bf16
+- **Poor convergence**: Adjust learning rate or warmup steps
+- **Multi-GPU issues**: Check FSDP configuration
\ No newline at end of file
diff --git a/config/training_config.yaml b/config/training_config.yaml
new file mode 100644
index 0000000..ba68d02
--- /dev/null
+++ b/config/training_config.yaml
@@ -0,0 +1,36 @@
+experiment:
+ name: "progressive_reasoning_experiment"
+ base_model: "microsoft/DialoGPT-small" # Lightweight model for testing
+ output_dir: "./outputs"
+ use_wandb: false
+ wandb_project: "matsuo-llm-comp-2025"
+
+model:
+ load_in_4bit: false # Disable quantization for small model
+ bnb_4bit_compute_dtype: "bfloat16"
+ bnb_4bit_use_double_quant: true
+ device_map: "auto"
+
+progressive_stages:
+ - name: "basic_cot"
+ description: "Basic Chain-of-Thought reasoning"
+ dataset_path: "./data/basic_cot/"
+ adapter_config:
+ r: 8
+ lora_alpha: 16
+ lora_dropout: 0.1
+ target_modules: ["c_attn", "c_proj"]
+ training:
+ num_epochs: 2
+ per_device_batch_size: 8 # Increase batch size for small model
+ gradient_accumulation_steps: 2 # Reduce accumulation steps
+ learning_rate: 5e-4 # Higher learning rate for faster training
+ warmup_steps: 50
+ max_length: 1024 # Shorter sequences
+
+evaluation:
+ benchmarks:
+ - "HLE" # Humanity's Last Exam
+ - "Do-Not-Answer"
+ save_results: true
+ results_dir: "./outputs/evaluation_results"
\ No newline at end of file
diff --git a/config/training_config_13b.yaml b/config/training_config_13b.yaml
new file mode 100644
index 0000000..59fd626
--- /dev/null
+++ b/config/training_config_13b.yaml
@@ -0,0 +1,83 @@
+experiment:
+ name: "progressive_reasoning_13b"
+ base_model: "meta-llama/Llama-3.2-13B" # 13B model
+ output_dir: "./outputs"
+ use_wandb: true
+ wandb_project: "matsuo-llm-comp-2025"
+
+model:
+ load_in_4bit: true
+ bnb_4bit_compute_dtype: "bfloat16"
+ bnb_4bit_use_double_quant: true
+ bnb_4bit_quant_type: "nf4"
+ device_map: "auto"
+ gradient_checkpointing: true
+ use_flash_attention_2: true
+
+progressive_stages:
+ - name: "basic_cot"
+ description: "Basic Chain-of-Thought reasoning"
+ dataset_path: "./data/basic_cot/"
+ adapter_config:
+ r: 32 # Higher rank for 13B models
+ lora_alpha: 64
+ lora_dropout: 0.05
+ target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 1
+ per_device_batch_size: 1
+ gradient_accumulation_steps: 32
+ learning_rate: 1e-4
+ warmup_steps: 100
+ max_length: 2048
+ bf16: true
+ max_grad_norm: 0.3
+ weight_decay: 0.001
+
+ - name: "math_reasoning"
+ description: "Mathematical reasoning with think tags"
+ dataset_path: "./data/math_reasoning/"
+ inherit_from: "basic_cot"
+ adapter_config:
+ r: 64
+ lora_alpha: 128
+ lora_dropout: 0.05
+ target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 2
+ per_device_batch_size: 1
+ gradient_accumulation_steps: 64
+ learning_rate: 8e-5
+ warmup_steps: 200
+ max_length: 4096
+ bf16: true
+ max_grad_norm: 0.3
+
+ - name: "complex_reasoning"
+ description: "Complex multi-step reasoning"
+ dataset_path: "./data/complex_reasoning/"
+ inherit_from: "math_reasoning"
+ adapter_config:
+ r: 128 # Maximum rank for 13B models
+ lora_alpha: 256
+ lora_dropout: 0.05
+ target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 1
+ per_device_batch_size: 1
+ gradient_accumulation_steps: 128
+ learning_rate: 5e-5
+ warmup_steps: 300
+ max_length: 8192
+ bf16: true
+ max_grad_norm: 0.3
+
+evaluation:
+ benchmarks:
+ - "HLE"
+ - "Do-Not-Answer"
+ save_results: true
+ results_dir: "./outputs/evaluation_results"
\ No newline at end of file
diff --git a/config/training_config_70b.yaml b/config/training_config_70b.yaml
new file mode 100644
index 0000000..ed44f42
--- /dev/null
+++ b/config/training_config_70b.yaml
@@ -0,0 +1,101 @@
+experiment:
+ name: "progressive_reasoning_70b"
+ base_model: "meta-llama/Llama-3.2-70B" # 70B model - requires significant resources
+ output_dir: "./outputs"
+ use_wandb: true
+ wandb_project: "matsuo-llm-comp-2025"
+
+model:
+ load_in_4bit: true
+ bnb_4bit_compute_dtype: "bfloat16"
+ bnb_4bit_use_double_quant: true
+ bnb_4bit_quant_type: "nf4"
+ device_map: "auto"
+ gradient_checkpointing: true
+ use_flash_attention_2: true
+
+progressive_stages:
+ - name: "basic_cot"
+ description: "Basic Chain-of-Thought reasoning"
+ dataset_path: "./data/basic_cot/"
+ adapter_config:
+ r: 64 # Even higher rank for 70B models
+ lora_alpha: 128
+ lora_dropout: 0.05
+ target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 1
+ per_device_batch_size: 1
+ gradient_accumulation_steps: 64
+ learning_rate: 5e-5 # Lower learning rate for stability
+ warmup_steps: 200
+ max_length: 2048
+ bf16: true
+ max_grad_norm: 0.3
+ weight_decay: 0.001
+ dataloader_num_workers: 2
+
+ - name: "math_reasoning"
+ description: "Mathematical reasoning with think tags"
+ dataset_path: "./data/math_reasoning/"
+ inherit_from: "basic_cot"
+ adapter_config:
+ r: 128
+ lora_alpha: 256
+ lora_dropout: 0.05
+ target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 1
+ per_device_batch_size: 1
+ gradient_accumulation_steps: 128
+ learning_rate: 3e-5
+ warmup_steps: 300
+ max_length: 4096
+ bf16: true
+ max_grad_norm: 0.3
+ dataloader_num_workers: 2
+
+ - name: "complex_reasoning"
+ description: "Complex multi-step reasoning"
+ dataset_path: "./data/complex_reasoning/"
+ inherit_from: "math_reasoning"
+ adapter_config:
+ r: 256 # Maximum rank for 70B models
+ lora_alpha: 512
+ lora_dropout: 0.05
+ target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 1
+ per_device_batch_size: 1
+ gradient_accumulation_steps: 256
+ learning_rate: 2e-5
+ warmup_steps: 500
+ max_length: 8192
+ bf16: true
+ max_grad_norm: 0.3
+ dataloader_num_workers: 2
+
+evaluation:
+ benchmarks:
+ - "HLE"
+ - "Do-Not-Answer"
+ save_results: true
+ results_dir: "./outputs/evaluation_results"
+
+# Additional settings for 70B models
+optimization:
+ gradient_checkpointing: true
+ gradient_checkpointing_kwargs:
+ use_reentrant: false
+ ddp_find_unused_parameters: false
+ # Multi-GPU settings
+ fsdp: "full_shard auto_wrap"
+ fsdp_transformer_layer_cls_to_wrap: "LlamaDecoderLayer"
+ fsdp_min_num_params: 1000000
+ fsdp_config:
+ min_num_params: 1000000
+ sharding_strategy: "FULL_SHARD"
+ cpu_offload: false
\ No newline at end of file
diff --git a/config/training_config_gemma2_small.yaml b/config/training_config_gemma2_small.yaml
new file mode 100644
index 0000000..fa035d6
--- /dev/null
+++ b/config/training_config_gemma2_small.yaml
@@ -0,0 +1,91 @@
+experiment:
+ name: "progressive_reasoning_gemma2_small"
+ base_model: "google/gemma-2-2b-it" # Instruction-tuned version
+ output_dir: "./outputs"
+ use_wandb: true
+ wandb_project: "matsuo-llm-comp-2025"
+
+model:
+ load_in_4bit: false # 2B model is manageable without quantization
+ bnb_4bit_compute_dtype: "bfloat16"
+ bnb_4bit_use_double_quant: true
+ device_map: "auto"
+ gradient_checkpointing: false
+ use_flash_attention_2: false
+ use_eager_attention: true # Recommended for Gemma models
+
+progressive_stages:
+ - name: "basic_cot"
+ description: "Basic Chain-of-Thought reasoning"
+ dataset_path: "./data/basic_cot/"
+ adapter_config:
+ r: 8 # Start with smaller rank for small model
+ lora_alpha: 16
+ lora_dropout: 0.1
+ target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 3
+ per_device_batch_size: 8 # Larger batch size for small model
+ gradient_accumulation_steps: 2
+ learning_rate: 5e-4 # Higher learning rate for small model
+ warmup_steps: 50
+ max_length: 1024
+ bf16: true
+ max_grad_norm: 1.0
+ weight_decay: 0.001
+ save_steps: 50
+ logging_steps: 10
+
+ - name: "math_reasoning"
+ description: "Mathematical reasoning with think tags"
+ dataset_path: "./data/math_reasoning/"
+ inherit_from: "basic_cot"
+ adapter_config:
+ r: 16
+ lora_alpha: 32
+ lora_dropout: 0.1
+ target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 3
+ per_device_batch_size: 4
+ gradient_accumulation_steps: 4
+ learning_rate: 3e-4
+ warmup_steps: 100
+ max_length: 2048
+ bf16: true
+ max_grad_norm: 1.0
+
+ - name: "complex_reasoning"
+ description: "Complex multi-step reasoning with Mixture-of-Thoughts"
+ dataset_path: "open-r1/Mixture-of-Thoughts" # HuggingFace dataset
+ inherit_from: "math_reasoning"
+ adapter_config:
+ r: 32
+ lora_alpha: 64
+ lora_dropout: 0.1
+ target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 1 # Large dataset, fewer epochs
+ per_device_batch_size: 2
+ gradient_accumulation_steps: 8
+ learning_rate: 2e-4
+ warmup_steps: 200
+ max_length: 4096
+ bf16: true
+ max_grad_norm: 1.0
+ save_steps: 500
+ logging_steps: 50
+ dataset_config:
+ streaming: true
+ max_samples: 30000
+ split: "train"
+
+evaluation:
+ benchmarks:
+ - "HLE"
+ - "Do-Not-Answer"
+ save_results: true
+ results_dir: "./outputs/evaluation_results"
\ No newline at end of file
diff --git a/config/training_config_gemma3_1b.yaml b/config/training_config_gemma3_1b.yaml
new file mode 100644
index 0000000..2433612
--- /dev/null
+++ b/config/training_config_gemma3_1b.yaml
@@ -0,0 +1,102 @@
+experiment:
+ name: "progressive_reasoning_gemma3_1b"
+ base_model: "google/gemma-3-1b-pt" # Gemma 3 1B pretrained model
+ output_dir: "./outputs"
+ use_wandb: true
+ wandb_project: "matsuo-llm-comp-2025"
+
+model:
+ load_in_4bit: false
+ bnb_4bit_compute_dtype: "bfloat16"
+ bnb_4bit_use_double_quant: true
+ device_map: "auto"
+ gradient_checkpointing: false # Not needed for small models
+ use_flash_attention_2: false
+ use_eager_attention: true
+
+progressive_stages:
+ - name: "basic_cot"
+ description: "Basic Chain-of-Thought reasoning"
+ dataset_path: "./data/basic_cot/"
+ adapter_config:
+ r: 8
+ lora_alpha: 16
+ lora_dropout: 0.1
+ target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] # Gemma attention modules
+ init_lora_weights: true
+ training:
+ num_epochs: 2
+ per_device_batch_size: 8
+ gradient_accumulation_steps: 2
+ learning_rate: 5e-4
+ warmup_steps: 50
+ max_length: 1024
+ fp16: false
+ bf16: true
+ max_grad_norm: 1.0
+ weight_decay: 0.001
+ save_steps: 100
+ logging_steps: 10
+
+ - name: "math_reasoning"
+ description: "Mathematical reasoning with OpenR1-Math-220k dataset"
+ dataset_path: "open-r1/OpenR1-Math-220k" # HuggingFace dataset
+ inherit_from: "basic_cot"
+ adapter_config:
+ r: 16
+ lora_alpha: 32
+ lora_dropout: 0.1
+ target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 1 # Large dataset, fewer epochs
+ per_device_batch_size: 4
+ gradient_accumulation_steps: 4
+ learning_rate: 3e-4
+ warmup_steps: 100
+ max_length: 2048
+ bf16: true
+ max_grad_norm: 1.0
+ weight_decay: 0.001
+ save_steps: 1000
+ logging_steps: 100
+ dataset_config:
+ # OpenR1-Math-220k specific settings
+ streaming: true # Use streaming for large dataset
+ max_samples: 200000 # Limit samples for faster training
+ split: "train"
+
+ - name: "complex_reasoning"
+ description: "Complex multi-step reasoning with Mixture-of-Thoughts"
+ dataset_path: "open-r1/Mixture-of-Thoughts" # HuggingFace dataset
+ inherit_from: "math_reasoning"
+ adapter_config:
+ r: 32
+ lora_alpha: 64
+ lora_dropout: 0.1
+ target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 1 # Large dataset, fewer epochs
+ per_device_batch_size: 2
+ gradient_accumulation_steps: 8
+ learning_rate: 2e-4
+ warmup_steps: 200
+ max_length: 4096
+ bf16: true
+ max_grad_norm: 1.0
+ weight_decay: 0.001
+ save_steps: 500
+ logging_steps: 50
+ dataset_config:
+ # Mixture-of-Thoughts specific settings
+ streaming: true # Use streaming for large dataset
+ max_samples: 30000 # Limit samples for faster training
+ split: "train"
+
+evaluation:
+ benchmarks:
+ - "HLE"
+ - "Do-Not-Answer"
+ save_results: true
+ results_dir: "./outputs/evaluation_results"
diff --git a/config/training_config_gemma3_1b_cpu_offload.yaml b/config/training_config_gemma3_1b_cpu_offload.yaml
new file mode 100644
index 0000000..2b4158c
--- /dev/null
+++ b/config/training_config_gemma3_1b_cpu_offload.yaml
@@ -0,0 +1,133 @@
+experiment:
+ name: "progressive_reasoning_gemma3_1b_cpu_offload"
+ base_model: "google/gemma-3-1b-pt" # Using Gemma 3 1B
+ output_dir: "./outputs"
+ use_wandb: true
+ wandb_project: "matsuo-llm-comp-2025"
+
+model:
+ load_in_4bit: true # Enable 4-bit quantization for QLoRA
+ bnb_4bit_compute_dtype: "bfloat16"
+ bnb_4bit_use_double_quant: true
+ bnb_4bit_quant_type: "nf4"
+ device_map: "auto" # Let accelerate handle device placement
+ max_memory:
+ 0: "5GB" # Limit GPU memory to 5GB (leave room for CUDA kernels)
+ "cpu": "32GB" # Allow up to 32GB CPU RAM
+ offload_folder: "./offload" # Directory for disk offloading if needed
+ gradient_checkpointing: true # Trade compute for memory
+ use_flash_attention_2: false
+ use_eager_attention: true
+
+progressive_stages:
+ - name: "basic_cot"
+ description: "Basic Chain-of-Thought reasoning"
+ dataset_path: "./data/basic_cot/"
+ adapter_config:
+ r: 8 # Lower rank for memory efficiency
+ lora_alpha: 16
+ lora_dropout: 0.1
+ target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 2
+ per_device_batch_size: 2 # Smaller batch size
+ gradient_accumulation_steps: 8 # Compensate with gradient accumulation
+ learning_rate: 5e-4
+ warmup_steps: 50
+ max_length: 512 # Shorter sequences for memory
+ bf16: true
+ max_grad_norm: 1.0
+ weight_decay: 0.001
+ save_steps: 100
+ logging_steps: 10
+ dataloader_num_workers: 0 # Disable multiprocessing to save memory
+ optim: "paged_adamw_8bit" # Use 8-bit optimizer
+
+ - name: "math_reasoning"
+ description: "Mathematical reasoning with OpenR1-Math-220k dataset"
+ dataset_path: "open-r1/OpenR1-Math-220k"
+ inherit_from: "basic_cot"
+ adapter_config:
+ r: 16
+ lora_alpha: 32
+ lora_dropout: 0.1
+ target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 1
+ per_device_batch_size: 1 # Minimal batch size
+ gradient_accumulation_steps: 16
+ learning_rate: 3e-4
+ warmup_steps: 100
+ max_length: 1024
+ bf16: true
+ max_grad_norm: 1.0
+ weight_decay: 0.001
+ save_steps: 1000
+ logging_steps: 100
+ optim: "paged_adamw_8bit"
+ dataset_config:
+ streaming: true
+ max_samples: 200000 # Limit samples for faster training
+ split: "train"
+
+ - name: "complex_reasoning"
+ description: "Complex multi-step reasoning with Mixture-of-Thoughts"
+ dataset_path: "open-r1/Mixture-of-Thoughts" # HuggingFace dataset
+ inherit_from: "math_reasoning"
+ adapter_config:
+ r: 32
+ lora_alpha: 64
+ lora_dropout: 0.1
+ target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 1
+ per_device_batch_size: 1
+ gradient_accumulation_steps: 32
+ learning_rate: 2e-4
+ warmup_steps: 200
+ max_length: 2048
+ bf16: true
+ max_grad_norm: 1.0
+ weight_decay: 0.001
+ optim: "paged_adamw_8bit"
+ save_steps: 500
+ logging_steps: 50
+ dataset_config:
+ streaming: true
+ max_samples: 300000 # Limited for CPU offload config
+ split: "train"
+
+evaluation:
+ benchmarks:
+ - "HLE"
+ - "Do-Not-Answer"
+ save_results: true
+ results_dir: "./outputs/evaluation_results"
+
+# DeepSpeed configuration for advanced CPU offloading (optional)
+# Uncomment to use DeepSpeed ZeRO-2 with CPU offload
+# deepspeed:
+# zero_optimization:
+# stage: 2
+# offload_optimizer:
+# device: "cpu"
+# pin_memory: true
+# offload_param:
+# device: "cpu"
+# pin_memory: true
+# overlap_comm: true
+# contiguous_gradients: true
+# sub_group_size: 1e9
+# reduce_bucket_size: 1e6
+
+# FSDP configuration for distributed training (optional)
+# Uncomment to use FSDP with CPU offload
+# fsdp:
+# sharding_strategy: "FULL_SHARD"
+# cpu_offload: true
+# auto_wrap_policy: "TRANSFORMER_BASED_WRAP"
+# transformer_layer_cls_to_wrap: "GemmaDecoderLayer"
+# min_num_params: 1e6
diff --git a/config/training_config_large.yaml b/config/training_config_large.yaml
new file mode 100644
index 0000000..22cbd3b
--- /dev/null
+++ b/config/training_config_large.yaml
@@ -0,0 +1,98 @@
+experiment:
+ name: "progressive_reasoning_large_model"
+ base_model: "meta-llama/Llama-3.2-8B" # Or other whitelisted models
+ output_dir: "./outputs"
+ use_wandb: true
+ wandb_project: "matsuo-llm-comp-2025"
+
+model:
+ load_in_4bit: true # Enable 4-bit quantization for memory efficiency
+ bnb_4bit_compute_dtype: "bfloat16"
+ bnb_4bit_use_double_quant: true
+ bnb_4bit_quant_type: "nf4"
+ device_map: "auto"
+ # Additional memory optimizations
+ gradient_checkpointing: true
+ use_flash_attention_2: true # If available
+
+progressive_stages:
+ - name: "basic_cot"
+ description: "Basic Chain-of-Thought reasoning"
+ dataset_path: "./data/basic_cot/"
+ adapter_config:
+ r: 16 # Larger rank for bigger models
+ lora_alpha: 32
+ lora_dropout: 0.05
+ target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"]
+ init_lora_weights: true # Identity initialization
+ training:
+ num_epochs: 1
+ per_device_batch_size: 1 # Small batch size for large models
+ gradient_accumulation_steps: 16 # Effective batch size = 16
+ learning_rate: 2e-4
+ warmup_steps: 100
+ max_length: 2048
+ fp16: false
+ bf16: true
+ max_grad_norm: 0.3
+ weight_decay: 0.001
+ save_steps: 50
+ logging_steps: 10
+
+ - name: "math_reasoning"
+ description: "Mathematical reasoning with think tags"
+ dataset_path: "./data/math_reasoning/"
+ inherit_from: "basic_cot"
+ adapter_config:
+ r: 32 # Increase rank for more complex tasks
+ lora_alpha: 64
+ lora_dropout: 0.05
+ target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 2
+ per_device_batch_size: 1
+ gradient_accumulation_steps: 32 # Effective batch size = 32
+ learning_rate: 1e-4
+ warmup_steps: 200
+ max_length: 4096
+ bf16: true
+ max_grad_norm: 0.3
+ weight_decay: 0.001
+
+ - name: "complex_reasoning"
+ description: "Complex multi-step reasoning"
+ dataset_path: "./data/complex_reasoning/"
+ inherit_from: "math_reasoning"
+ adapter_config:
+ r: 64 # Maximum rank for most complex tasks
+ lora_alpha: 128
+ lora_dropout: 0.05
+ target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 2
+ per_device_batch_size: 1
+ gradient_accumulation_steps: 64 # Effective batch size = 64
+ learning_rate: 5e-5
+ warmup_steps: 300
+ max_length: 8192
+ bf16: true
+ max_grad_norm: 0.3
+ weight_decay: 0.001
+
+evaluation:
+ benchmarks:
+ - "HLE"
+ - "Do-Not-Answer"
+ save_results: true
+ results_dir: "./outputs/evaluation_results"
+
+# Memory optimization settings
+optimization:
+ gradient_checkpointing: true
+ gradient_checkpointing_kwargs:
+ use_reentrant: false
+ ddp_find_unused_parameters: false
+ fsdp: "full_shard auto_wrap" # For multi-GPU setups
+ fsdp_transformer_layer_cls_to_wrap: "LlamaDecoderLayer"
\ No newline at end of file
diff --git a/config/training_config_llama_auth.yaml b/config/training_config_llama_auth.yaml
new file mode 100644
index 0000000..090df6f
--- /dev/null
+++ b/config/training_config_llama_auth.yaml
@@ -0,0 +1,85 @@
+experiment:
+ name: "progressive_reasoning_llama_auth"
+ base_model: "meta-llama/Llama-3.2-8B"
+ output_dir: "./outputs"
+ use_wandb: true
+ wandb_project: "matsuo-llm-comp-2025"
+
+model:
+ load_in_4bit: true
+ bnb_4bit_compute_dtype: "bfloat16"
+ bnb_4bit_use_double_quant: true
+ bnb_4bit_quant_type: "nf4"
+ device_map: "auto"
+ gradient_checkpointing: true
+ use_flash_attention_2: true
+ # Add your HuggingFace token here, or set HF_TOKEN environment variable
+ # hf_token: "your_token_here"
+
+progressive_stages:
+ - name: "basic_cot"
+ description: "Basic Chain-of-Thought reasoning"
+ dataset_path: "./data/basic_cot/"
+ adapter_config:
+ r: 16
+ lora_alpha: 32
+ lora_dropout: 0.05
+ target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 1
+ per_device_batch_size: 1
+ gradient_accumulation_steps: 16
+ learning_rate: 2e-4
+ warmup_steps: 100
+ max_length: 2048
+ bf16: true
+ max_grad_norm: 0.3
+ weight_decay: 0.001
+
+ - name: "math_reasoning"
+ description: "Mathematical reasoning with think tags"
+ dataset_path: "./data/math_reasoning/"
+ inherit_from: "basic_cot"
+ adapter_config:
+ r: 32
+ lora_alpha: 64
+ lora_dropout: 0.05
+ target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 2
+ per_device_batch_size: 1
+ gradient_accumulation_steps: 32
+ learning_rate: 1e-4
+ warmup_steps: 200
+ max_length: 4096
+ bf16: true
+ max_grad_norm: 0.3
+
+ - name: "complex_reasoning"
+ description: "Complex multi-step reasoning"
+ dataset_path: "./data/complex_reasoning/"
+ inherit_from: "math_reasoning"
+ adapter_config:
+ r: 64
+ lora_alpha: 128
+ lora_dropout: 0.05
+ target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 2
+ per_device_batch_size: 1
+ gradient_accumulation_steps: 64
+ learning_rate: 5e-5
+ warmup_steps: 300
+ max_length: 8192
+ bf16: true
+ max_grad_norm: 0.3
+
+evaluation:
+ benchmarks:
+ - "HLE"
+ - "Do-Not-Answer"
+ save_results: true
+ results_dir: "./outputs/evaluation_results"
\ No newline at end of file
diff --git a/config/training_config_public.yaml b/config/training_config_public.yaml
new file mode 100644
index 0000000..a1abea4
--- /dev/null
+++ b/config/training_config_public.yaml
@@ -0,0 +1,82 @@
+experiment:
+ name: "progressive_reasoning_public_model"
+ base_model: "microsoft/DialoGPT-medium" # Public model, no authentication needed
+ output_dir: "./outputs"
+ use_wandb: false
+ wandb_project: "matsuo-llm-comp-2025"
+
+model:
+ load_in_4bit: false # DialoGPT is smaller, quantization not needed
+ bnb_4bit_compute_dtype: "bfloat16"
+ bnb_4bit_use_double_quant: true
+ device_map: "auto"
+ gradient_checkpointing: false
+
+progressive_stages:
+ - name: "basic_cot"
+ description: "Basic Chain-of-Thought reasoning"
+ dataset_path: "./data/basic_cot/"
+ adapter_config:
+ r: 16
+ lora_alpha: 32
+ lora_dropout: 0.1
+ target_modules: ["c_attn", "c_proj"] # GPT-2 style attention modules
+ init_lora_weights: true
+ training:
+ num_epochs: 2
+ per_device_batch_size: 4
+ gradient_accumulation_steps: 4
+ learning_rate: 2e-4
+ warmup_steps: 100
+ max_length: 1024
+ fp16: false
+ bf16: false # Use fp32 for smaller models
+ max_grad_norm: 1.0
+ weight_decay: 0.001
+
+ - name: "math_reasoning"
+ description: "Mathematical reasoning with think tags"
+ dataset_path: "./data/math_reasoning/"
+ inherit_from: "basic_cot"
+ adapter_config:
+ r: 32
+ lora_alpha: 64
+ lora_dropout: 0.1
+ target_modules: ["c_attn", "c_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 3
+ per_device_batch_size: 2
+ gradient_accumulation_steps: 8
+ learning_rate: 1e-4
+ warmup_steps: 200
+ max_length: 2048
+ bf16: false
+ max_grad_norm: 1.0
+
+ - name: "complex_reasoning"
+ description: "Complex multi-step reasoning"
+ dataset_path: "./data/complex_reasoning/"
+ inherit_from: "math_reasoning"
+ adapter_config:
+ r: 64
+ lora_alpha: 128
+ lora_dropout: 0.1
+ target_modules: ["c_attn", "c_proj"]
+ init_lora_weights: true
+ training:
+ num_epochs: 2
+ per_device_batch_size: 1
+ gradient_accumulation_steps: 16
+ learning_rate: 5e-5
+ warmup_steps: 300
+ max_length: 4096
+ bf16: false
+ max_grad_norm: 1.0
+
+evaluation:
+ benchmarks:
+ - "HLE"
+ - "Do-Not-Answer"
+ save_results: true
+ results_dir: "./outputs/evaluation_results"
\ No newline at end of file
diff --git a/devenv.lock b/devenv.lock
new file mode 100644
index 0000000..d06c441
--- /dev/null
+++ b/devenv.lock
@@ -0,0 +1,139 @@
+{
+ "nodes": {
+ "devenv": {
+ "locked": {
+ "dir": "src/modules",
+ "lastModified": 1751909516,
+ "owner": "cachix",
+ "repo": "devenv",
+ "rev": "36e4cf7d6cb89862e69efce4e5c147ac2e4d38f9",
+ "type": "github"
+ },
+ "original": {
+ "dir": "src/modules",
+ "owner": "cachix",
+ "repo": "devenv",
+ "type": "github"
+ }
+ },
+ "flake-compat": {
+ "flake": false,
+ "locked": {
+ "lastModified": 1747046372,
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+ "type": "github"
+ },
+ "original": {
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "type": "github"
+ }
+ },
+ "flake-compat_2": {
+ "flake": false,
+ "locked": {
+ "lastModified": 1747046372,
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+ "type": "github"
+ },
+ "original": {
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "type": "github"
+ }
+ },
+ "git-hooks": {
+ "inputs": {
+ "flake-compat": "flake-compat",
+ "gitignore": "gitignore",
+ "nixpkgs": [
+ "nixpkgs"
+ ]
+ },
+ "locked": {
+ "lastModified": 1750779888,
+ "owner": "cachix",
+ "repo": "git-hooks.nix",
+ "rev": "16ec914f6fb6f599ce988427d9d94efddf25fe6d",
+ "type": "github"
+ },
+ "original": {
+ "owner": "cachix",
+ "repo": "git-hooks.nix",
+ "type": "github"
+ }
+ },
+ "gitignore": {
+ "inputs": {
+ "nixpkgs": [
+ "git-hooks",
+ "nixpkgs"
+ ]
+ },
+ "locked": {
+ "lastModified": 1709087332,
+ "owner": "hercules-ci",
+ "repo": "gitignore.nix",
+ "rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
+ "type": "github"
+ },
+ "original": {
+ "owner": "hercules-ci",
+ "repo": "gitignore.nix",
+ "type": "github"
+ }
+ },
+ "nixpkgs": {
+ "locked": {
+ "lastModified": 1751792365,
+ "owner": "NixOS",
+ "repo": "nixpkgs",
+ "rev": "1fd8bada0b6117e6c7eb54aad5813023eed37ccb",
+ "type": "github"
+ },
+ "original": {
+ "owner": "NixOS",
+ "ref": "nixos-unstable",
+ "repo": "nixpkgs",
+ "type": "github"
+ }
+ },
+ "nixpkgs-python": {
+ "inputs": {
+ "flake-compat": "flake-compat_2",
+ "nixpkgs": [
+ "nixpkgs"
+ ]
+ },
+ "locked": {
+ "lastModified": 1749760516,
+ "owner": "cachix",
+ "repo": "nixpkgs-python",
+ "rev": "908dbb466af5955ea479ac95953333fd64387216",
+ "type": "github"
+ },
+ "original": {
+ "owner": "cachix",
+ "repo": "nixpkgs-python",
+ "type": "github"
+ }
+ },
+ "root": {
+ "inputs": {
+ "devenv": "devenv",
+ "git-hooks": "git-hooks",
+ "nixpkgs": "nixpkgs",
+ "nixpkgs-python": "nixpkgs-python",
+ "pre-commit-hooks": [
+ "git-hooks"
+ ]
+ }
+ }
+ },
+ "root": "root",
+ "version": 7
+}
diff --git a/flake-minimal.nix b/flake-minimal.nix
new file mode 100644
index 0000000..d5fde54
--- /dev/null
+++ b/flake-minimal.nix
@@ -0,0 +1,95 @@
+{
+ description = "Progressive LLM Training for 松尾研LLMコンペ2025 (Minimal)";
+
+ inputs = {
+ nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
+ flake-utils.url = "github:numtide/flake-utils";
+ };
+
+ outputs = { self, nixpkgs, flake-utils }:
+ flake-utils.lib.eachDefaultSystem (system:
+ let
+ pkgs = import nixpkgs {
+ inherit system;
+ config = {
+ allowUnfree = true;
+ cudaSupport = true;
+ };
+ };
+
+ # Python 3.11 for better compatibility
+ python = pkgs.python311;
+
+ # Minimal Python packages
+ pythonWithPackages = python.withPackages (ps: with ps; [
+ # Core essentials only
+ torch
+ transformers
+ numpy
+
+ # Essential dependencies
+ pyyaml
+
+ # Build tools
+ pip
+ setuptools
+ wheel
+ ]);
+
+ in
+ {
+ devShells.default = pkgs.mkShell {
+ buildInputs = with pkgs; [
+ # Python with packages
+ pythonWithPackages
+
+ # Build tools
+ gcc
+ cmake
+ ninja
+ pkg-config
+
+ # Git
+ git
+ git-lfs
+
+ # Libraries needed for Python packages
+ openssl
+ zlib
+ glib
+ stdenv.cc.cc.lib
+
+ # CUDA support
+ cudaPackages.cudatoolkit
+ cudaPackages.cudnn
+ ];
+
+ shellHook = ''
+ echo "🚀 Progressive LLM Training Environment (Minimal)"
+ echo "Python version: $(python --version)"
+ echo "PyTorch version: $(python -c 'import torch; print(torch.__version__)')"
+ echo "CUDA available: $(python -c 'import torch; print(torch.cuda.is_available())')"
+
+ # Set up CUDA environment
+ export CUDA_HOME=${pkgs.cudaPackages.cudatoolkit}
+ export CUDA_PATH=${pkgs.cudaPackages.cudatoolkit}
+ export LD_LIBRARY_PATH=${pkgs.cudaPackages.cudatoolkit}/lib:${pkgs.cudaPackages.cudnn}/lib:${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH
+
+ # Set Python path
+ export PYTHONPATH=$PWD/src:$PYTHONPATH
+
+ echo ""
+ echo "Note: This is a minimal configuration. Install additional packages with pip as needed:"
+ echo " pip install accelerate peft trl datasets bitsandbytes wandb jsonlines scikit-learn sentencepiece protobuf"
+ echo " pip install flash-attn --no-build-isolation"
+ '';
+
+ # Environment variables
+ CUDA_HOME = "${pkgs.cudaPackages.cudatoolkit}";
+ CUDA_PATH = "${pkgs.cudaPackages.cudatoolkit}";
+ NIX_SHELL_PRESERVE_PROMPT = 1;
+ LOCALE_ARCHIVE = "${pkgs.glibcLocales}/lib/locale/locale-archive";
+ LC_ALL = "en_US.UTF-8";
+ };
+ });
+}
\ No newline at end of file
diff --git a/flake.lock b/flake.lock
new file mode 100644
index 0000000..bd80c39
--- /dev/null
+++ b/flake.lock
@@ -0,0 +1,61 @@
+{
+ "nodes": {
+ "flake-utils": {
+ "inputs": {
+ "systems": "systems"
+ },
+ "locked": {
+ "lastModified": 1731533236,
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+ "owner": "numtide",
+ "repo": "flake-utils",
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+ "type": "github"
+ },
+ "original": {
+ "owner": "numtide",
+ "repo": "flake-utils",
+ "type": "github"
+ }
+ },
+ "nixpkgs": {
+ "locked": {
+ "lastModified": 1751792365,
+ "narHash": "sha256-J1kI6oAj25IG4EdVlg2hQz8NZTBNYvIS0l4wpr9KcUo=",
+ "owner": "NixOS",
+ "repo": "nixpkgs",
+ "rev": "1fd8bada0b6117e6c7eb54aad5813023eed37ccb",
+ "type": "github"
+ },
+ "original": {
+ "owner": "NixOS",
+ "ref": "nixos-unstable",
+ "repo": "nixpkgs",
+ "type": "github"
+ }
+ },
+ "root": {
+ "inputs": {
+ "flake-utils": "flake-utils",
+ "nixpkgs": "nixpkgs"
+ }
+ },
+ "systems": {
+ "locked": {
+ "lastModified": 1681028828,
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+ "owner": "nix-systems",
+ "repo": "default",
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+ "type": "github"
+ },
+ "original": {
+ "owner": "nix-systems",
+ "repo": "default",
+ "type": "github"
+ }
+ }
+ },
+ "root": "root",
+ "version": 7
+}
diff --git a/flake.nix b/flake.nix
new file mode 100644
index 0000000..d0d9d05
--- /dev/null
+++ b/flake.nix
@@ -0,0 +1,195 @@
+{
+ description = "Progressive LLM Training for 松尾研LLMコンペ2025";
+
+ inputs = {
+ nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
+ flake-utils.url = "github:numtide/flake-utils";
+ };
+
+ outputs = { self, nixpkgs, flake-utils }:
+ flake-utils.lib.eachDefaultSystem (system:
+ let
+ pkgs = import nixpkgs {
+ inherit system;
+ config = {
+ allowUnfree = true;
+ cudaSupport = true;
+ };
+ overlays = [
+ (final: prev: {
+ python311 = prev.python311.override {
+ packageOverrides = python-self: python-super: {
+ # Disable tests for problematic packages
+ pytest-doctestplus = python-super.pytest-doctestplus.overrideAttrs (oldAttrs: {
+ doCheck = false;
+ doInstallCheck = false;
+ pytestCheckPhase = "echo 'Skipping tests'";
+ });
+ # Also disable tests for jupyter-related packages if they cause issues
+ jupyter = python-super.jupyter.overrideAttrs (oldAttrs: {
+ doCheck = false;
+ doInstallCheck = false;
+ });
+ notebook = python-super.notebook.overrideAttrs (oldAttrs: {
+ doCheck = false;
+ doInstallCheck = false;
+ });
+ # Disable tests for psycopg and psycopg2
+ psycopg = python-super.psycopg.overrideAttrs (oldAttrs: {
+ doCheck = false;
+ doInstallCheck = false;
+ pytestCheckPhase = "echo 'Skipping tests'";
+ pythonImportsCheck = []; # Disable import checks
+ });
+ psycopg2 = python-super.psycopg2.overrideAttrs (oldAttrs: {
+ doCheck = false;
+ doInstallCheck = false;
+ pytestCheckPhase = "echo 'Skipping tests'";
+ pythonImportsCheck = []; # Disable import checks
+ });
+ # Disable tests for sqlframe
+ sqlframe = python-super.sqlframe.overrideAttrs (oldAttrs: {
+ doCheck = false;
+ doInstallCheck = false;
+ pytestCheckPhase = "echo 'Skipping tests'";
+ pythonImportsCheck = []; # Disable import checks
+ });
+ # Disable tests for accelerate
+ accelerate = python-super.accelerate.overrideAttrs (oldAttrs: {
+ doCheck = false;
+ doInstallCheck = false;
+ pytestCheckPhase = "echo 'Skipping tests'";
+ pythonImportsCheck = []; # Disable import checks
+ });
+ };
+ };
+ })
+ ];
+ };
+
+ # Python 3.11 for better compatibility
+ python = pkgs.python311;
+
+ # Python packages
+ pythonWithPackages = python.withPackages (ps: with ps; [
+ # Core ML packages
+ torch
+ torchvision
+ torchaudio
+ transformers
+ accelerate
+ datasets
+ tokenizers
+ scikit-learn
+
+ # Required dependencies from requirements.txt
+ pyyaml
+ jsonlines
+ sentencepiece
+ protobuf
+
+ # Additional useful packages
+ numpy
+ scipy
+ matplotlib
+ jupyter
+ notebook
+ ipython
+ pandas
+ rich # For TUI
+
+ # Development tools
+ black
+ flake8
+ pytest
+ mypy
+
+ # Build tools
+ pip
+ setuptools
+ wheel
+
+ # LLM specific packages
+ peft
+ trl
+ bitsandbytes
+ wandb
+ ]);
+
+ in
+ {
+ devShells.default = pkgs.mkShell {
+ buildInputs = with pkgs; [
+ # Python with packages
+ pythonWithPackages
+
+ # Build tools
+ gcc
+ cmake
+ ninja
+ pkg-config
+
+ # Git
+ git
+ git-lfs
+
+ # Development tools
+ htop
+ tmux
+ vim
+
+ # Libraries needed for Python packages
+ openssl
+ zlib
+ glib
+ stdenv.cc.cc.lib
+
+ # CUDA support
+ cudaPackages.cudatoolkit
+ cudaPackages.cudnn
+ ];
+
+ shellHook = ''
+ echo "🚀 Progressive LLM Training Environment"
+ echo "Python version: $(python --version)"
+ echo "PyTorch version: $(python -c 'import torch; print(torch.__version__)')"
+ echo "CUDA available: $(python -c 'import torch; print(torch.cuda.is_available())')"
+
+ # Set up CUDA environment
+ export CUDA_HOME=${pkgs.cudaPackages.cudatoolkit}
+ export CUDA_PATH=${pkgs.cudaPackages.cudatoolkit}
+ export LD_LIBRARY_PATH=${pkgs.cudaPackages.cudatoolkit}/lib:${pkgs.cudaPackages.cudnn}/lib:${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH
+
+ # Set Python path
+ export PYTHONPATH=$PWD/src:$PYTHONPATH
+
+ echo ""
+ echo "Available commands:"
+ echo " python scripts/train_progressive.py # Start training"
+ echo " python scripts/evaluate.py # Evaluate model"
+ echo " jupyter notebook # Start Jupyter"
+ echo ""
+
+ # Create data directory if not exists
+ mkdir -p data
+
+ # Prepare sample data if not exists
+ if [ ! -f "data/basic_cot/train.jsonl" ]; then
+ echo "Preparing sample datasets..."
+ python -c "from src.data_utils import prepare_sample_datasets; prepare_sample_datasets()" || echo "Sample data preparation skipped"
+ fi
+
+ # Note about flash-attn
+ echo "Note: flash-attn is not included in nixpkgs. If needed, install manually with:"
+ echo " pip install flash-attn --no-build-isolation"
+ '';
+
+ # Environment variables
+ CUDA_HOME = "${pkgs.cudaPackages.cudatoolkit}";
+ CUDA_PATH = "${pkgs.cudaPackages.cudatoolkit}";
+ NIX_SHELL_PRESERVE_PROMPT = 1;
+ LOCALE_ARCHIVE = "${pkgs.glibcLocales}/lib/locale/locale-archive";
+ LC_ALL = "en_US.UTF-8";
+ };
+ });
+}
\ No newline at end of file
diff --git a/requirements-cpu.txt b/requirements-cpu.txt
new file mode 100644
index 0000000..6e3167e
--- /dev/null
+++ b/requirements-cpu.txt
@@ -0,0 +1,15 @@
+# CPU build of PyTorch: pip does not accept a per-requirement --index-url, so
+# install it first, e.g.: pip install "torch>=2.0.0" --index-url https://download.pytorch.org/whl/cpu
+transformers>=4.40.0
+accelerate>=0.27.0
+peft>=0.11.0
+trl>=0.9.0
+datasets>=2.18.0
+bitsandbytes>=0.43.0
+wandb>=0.16.0
+pyyaml>=6.0
+jsonlines>=4.0.0
+scikit-learn>=1.3.0
+# flash-attn is not needed for CPU version
+sentencepiece>=0.2.0
+protobuf>=4.25.0
\ No newline at end of file
diff --git a/requirements-torch.txt b/requirements-torch.txt
new file mode 100644
index 0000000..7a1beba
--- /dev/null
+++ b/requirements-torch.txt
@@ -0,0 +1,3 @@
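+# PyTorch wheels built for CUDA 12.8 (cu128)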
+--index-url https://download.pytorch.org/whl/cu128
+torch>=2.0.0
+torchaudio>=2.0.0
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..534ab7a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,13 @@
+transformers>=4.40.0
+accelerate>=0.27.0
+peft>=0.11.0
+trl>=0.9.0
+datasets>=2.18.0
+bitsandbytes>=0.43.0
+wandb>=0.16.0
+pyyaml>=6.0
+jsonlines>=4.0.0
+scikit-learn>=1.3.0
+# flash-attn>=2.5.0 # Install separately with --no-build-isolation
+sentencepiece>=0.2.0
+protobuf>=4.25.0
diff --git a/scripts/analyze_adapter_size.py b/scripts/analyze_adapter_size.py
new file mode 100755
index 0000000..b05b0d3
--- /dev/null
+++ b/scripts/analyze_adapter_size.py
@@ -0,0 +1,137 @@
+#!/usr/bin/env python3
+"""
+Analyze the size and structure of LoRA adapters
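+
+Usage:
+ python scripts/analyze_adapter_size.py
+
+Reads config/training_config.yaml and, if a trained adapter exists under
+<output_dir>/adapters/basic_cot, reports its actual parameter counts.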
+"""
+
+import sys
+from pathlib import Path
+import torch
+import yaml
+from peft import PeftModel, LoraConfig
+
+# Add src to path
+sys.path.append(str(Path(__file__).parent.parent))
+
+from src.progressive_model import ProgressiveReasoningModel
+
+
+def analyze_adapter_sizes():
+ # Load configuration
+ with open("config/training_config.yaml") as f:
+ config = yaml.safe_load(f)
+
+ print("=" * 60)
+ print("LoRA Adapter Size Analysis")
+ print("=" * 60)
+
+ # Get adapter configuration from config
+ basic_cot_config = config["progressive_stages"][0]
+ adapter_config = basic_cot_config["adapter_config"]
+
+ print(f"\nConfiguration for 'basic_cot' adapter:")
+ print(f" - r (rank): {adapter_config['r']}")
+ print(f" - lora_alpha: {adapter_config['lora_alpha']}")
+ print(f" - lora_dropout: {adapter_config['lora_dropout']}")
+ print(f" - target_modules: {adapter_config['target_modules']}")
+
+ # Load the base model to get dimensions
+ print("\nLoading base model to analyze dimensions...")
+ model_wrapper = ProgressiveReasoningModel(config)
+ model_wrapper.setup_base_model()
+
+ # Analyze model architecture
+ print(f"\nBase model: {config['experiment']['base_model']}")
+
+ # Count parameters in base model
+ total_params = sum(p.numel() for p in model_wrapper.model.parameters())
+ print(f"Total base model parameters: {total_params:,}")
+
+ # Load saved adapter if it exists
+ adapter_path = Path(config["experiment"]["output_dir"]) / "adapters" / "basic_cot"
+ if adapter_path.exists():
+ print(f"\nLoading saved adapter from: {adapter_path}")
+
+ # Load adapter state dict
+ adapter_model_path = adapter_path / "adapter_model.safetensors"
+ if not adapter_model_path.exists():
+ adapter_model_path = adapter_path / "adapter_model.bin"
+
+ if adapter_model_path.exists():
+ if adapter_model_path.suffix == ".safetensors":
+ from safetensors.torch import load_file
+ adapter_weights = load_file(adapter_model_path)
+ else:
+ adapter_weights = torch.load(adapter_model_path, map_location="cpu")
+
+ print("\nLoRA Adapter Layer Details:")
+ print("-" * 60)
+
+ total_lora_params = 0
+ layer_info = {}
+
+ for name, tensor in adapter_weights.items():
+ size = tensor.numel()
+ total_lora_params += size
+
+ # Parse layer name
+ parts = name.split('.')
+ if 'lora_A' in name or 'lora_B' in name:
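+ # PEFT stores two low-rank factors per adapted module: lora_A with
+ # shape (r, in_features) and lora_B with shape (out_features, r);
+ # the weight update is (lora_alpha / r) * (lora_B @ lora_A).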
+ # Extract module info
+ module_name = '.'.join(parts[:-2])
+ lora_type = parts[-2] # lora_A or lora_B
+
+ if module_name not in layer_info:
+ layer_info[module_name] = {}
+
+ layer_info[module_name][lora_type] = {
+ 'shape': list(tensor.shape),
+ 'params': size
+ }
+
+ # Display layer information
+ for module, info in sorted(layer_info.items()):
+ print(f"\nModule: {module}")
+ if 'lora_A' in info and 'lora_B' in info:
+ shape_a = info['lora_A']['shape']
+ shape_b = info['lora_B']['shape']
+ params_a = info['lora_A']['params']
+ params_b = info['lora_B']['params']
+
+ print(f" LoRA A: {shape_a} = {params_a:,} parameters")
+ print(f" LoRA B: {shape_b} = {params_b:,} parameters")
+ print(f" Total: {params_a + params_b:,} parameters")
+
+ # Calculate original layer size (approximation)
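+ # in_features = shape_a[1], out_features = shape_b[0]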
+ original_size = shape_a[1] * shape_b[0]
+ compression_ratio = original_size / (params_a + params_b)
+ print(f" Original layer size (approx): {original_size:,} parameters")
+ print(f" Compression ratio: {compression_ratio:.1f}x")
+
+ print("\n" + "=" * 60)
+ print(f"Total LoRA parameters: {total_lora_params:,}")
+ print(f"Percentage of base model: {(total_lora_params / total_params) * 100:.2f}%")
+
+ # Calculate theoretical size
+ r = adapter_config['r']
+ num_modules = len(adapter_config['target_modules'])
+
+ # For GPT models, typical dimensions
+ if "DialoGPT" in config['experiment']['base_model']:
+ hidden_size = 768 # DialoGPT-small uses 768
+ print(f"\nTheoretical calculation (hidden_size={hidden_size}, r={r}):")
+ print(f" Per module: 2 * {hidden_size} * {r} = {2 * hidden_size * r:,} parameters")
+ print(f" Total ({num_modules} modules): {2 * hidden_size * r * num_modules:,} parameters")
+ else:
+ print(f"\nNo saved adapter found at: {adapter_path}")
+ print("Run training first to generate the adapter.")
+
+ # Show theoretical sizes based on config
+ r = adapter_config['r']
+ print(f"\nTheoretical LoRA sizes with r={r}:")
+ print(f" For hidden_size=768 (DialoGPT-small): {2 * 768 * r:,} params per module")
+ print(f" For hidden_size=1024 (medium models): {2 * 1024 * r:,} params per module")
+ print(f" For hidden_size=1280 (GPT-2 large): {2 * 1280 * r:,} params per module")
+
+
+if __name__ == "__main__":
+ analyze_adapter_sizes()
\ No newline at end of file
diff --git a/scripts/check_vram.py b/scripts/check_vram.py
new file mode 100644
index 0000000..869dc02
--- /dev/null
+++ b/scripts/check_vram.py
@@ -0,0 +1,199 @@
+#!/usr/bin/env python3
+"""
+Check VRAM usage and model memory requirements
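+
+Usage:
+ python scripts/check_vram.py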
+"""
+
+import torch
+import psutil
+import sys
+from pathlib import Path
+import yaml
+
+# Add src to path
+sys.path.append(str(Path(__file__).parent.parent))
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+
+def get_memory_info():
+ """Get current memory usage"""
+ if torch.cuda.is_available():
+ print("=== CUDA Information ===")
+ print(f"CUDA available: {torch.cuda.is_available()}")
+ print(f"CUDA device: {torch.cuda.get_device_name(0)}")
+ print(f"CUDA device count: {torch.cuda.device_count()}")
+
+ # Get VRAM info
+ vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
+ vram_reserved = torch.cuda.memory_reserved(0) / 1024**3
+ vram_allocated = torch.cuda.memory_allocated(0) / 1024**3
+ vram_free = vram_total - vram_allocated
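+ # reserved-but-unallocated memory is reusable by torch, so "free"
+ # here means total minus what tensors currently occupy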
+
+ print(f"\n=== VRAM Usage ===")
+ print(f"Total VRAM: {vram_total:.2f} GB")
+ print(f"Allocated VRAM: {vram_allocated:.2f} GB")
+ print(f"Reserved VRAM: {vram_reserved:.2f} GB")
+ print(f"Free VRAM: {vram_free:.2f} GB")
+ else:
+ print("CUDA not available!")
+
+ # Get system RAM info
+ ram = psutil.virtual_memory()
+ print(f"\n=== System RAM ===")
+ print(f"Total RAM: {ram.total / 1024**3:.2f} GB")
+ print(f"Available RAM: {ram.available / 1024**3:.2f} GB")
+ print(f"Used RAM: {ram.used / 1024**3:.2f} GB ({ram.percent}%)")
+
+
+def estimate_model_size(model_name: str, quantization: str = None):
+ """Estimate model memory requirements"""
+ print(f"\n=== Model Memory Estimation ===")
+ print(f"Model: {model_name}")
+
+ # Common model sizes (in billions of parameters)
+ model_sizes = {
+ "gemma-2-2b": 2.5,
+ "gemma-3-1b": 1.2,
+ "llama-3.2-8b": 8,
+ "llama-3.2-13b": 13,
+ "llama-3.2-70b": 70,
+ }
+
+ # Find model size
+ model_key = None
+ for key in model_sizes:
+ if key in model_name.lower():
+ model_key = key
+ break
+
+ if model_key:
+ params_billions = model_sizes[model_key]
+
+ # Memory estimates (rough)
+ fp32_gb = params_billions * 4 # 4 bytes per parameter
+ fp16_gb = params_billions * 2 # 2 bytes per parameter
+ int8_gb = params_billions * 1 # 1 byte per parameter
+ int4_gb = params_billions * 0.5 # 0.5 bytes per parameter
+
+ print(f"Estimated parameters: {params_billions}B")
+ print(f"Memory requirements:")
+ print(f" FP32: ~{fp32_gb:.1f} GB")
+ print(f" FP16/BF16: ~{fp16_gb:.1f} GB")
+ print(f" INT8: ~{int8_gb:.1f} GB")
+ print(f" INT4 (QLoRA): ~{int4_gb:.1f} GB")
+
+ # Add overhead for activations and gradients
+ print(f"\nWith training overhead:")
+ print(f" FP16 + LoRA: ~{fp16_gb * 1.5:.1f} GB")
+ print(f" INT4 + QLoRA: ~{int4_gb * 1.5:.1f} GB")
+ else:
+ print("Model size not recognized, unable to estimate memory requirements")
+
+
+def suggest_offloading_strategies():
+ """Suggest CPU offloading strategies"""
+ print("\n=== CPU Offloading Strategies ===")
+ print("\n1. **Device Map Auto with CPU Offload**")
+ print(" ```python")
+ print(" device_map = {")
+ print(" 'model.embed_tokens': 'cpu',")
+ print(" 'model.layers.0': 0, # GPU")
+ print(" 'model.layers.1': 0, # GPU")
+ print(" 'model.layers.2': 'cpu', # CPU")
+ print(" # ... distribute layers between GPU and CPU")
+ print(" }")
+ print(" ```")
+
+ print("\n2. **Accelerate's CPU Offload**")
+ print(" ```yaml")
+ print(" model:")
+ print(" device_map: 'auto'")
+ print(" max_memory:")
+ print(" 0: '4GB' # Limit GPU memory")
+ print(" 'cpu': '20GB' # Allow CPU memory")
+ print(" ```")
+
+ print("\n3. **DeepSpeed ZeRO-Offload**")
+ print(" - ZeRO-2: Offload optimizer states to CPU")
+ print(" - ZeRO-3: Offload optimizer states and parameters to CPU")
+ print(" ```yaml")
+ print(" deepspeed:")
+ print(" zero_optimization:")
+ print(" stage: 2")
+ print(" offload_optimizer:")
+ print(" device: 'cpu'")
+ print(" ```")
+
+ print("\n4. **Gradient Checkpointing + CPU Offload**")
+ print(" - Trade compute for memory")
+ print(" - Combine with layer-wise CPU offloading")
+
+ print("\n5. **QLoRA with CPU Offload**")
+ print(" - 4-bit quantization reduces base model size")
+ print(" - Only LoRA parameters on GPU")
+ print(" - Base model layers can be on CPU")
+
+
+def check_config_compatibility(config_path: str):
+ """Check if config is compatible with CPU offloading"""
+ if Path(config_path).exists():
+ with open(config_path) as f:
+ config = yaml.safe_load(f)
+
+ print(f"\n=== Config Analysis: {config_path} ===")
+ model_config = config.get("model", {})
+
+ print(f"Current settings:")
+ print(f" 4-bit quantization: {model_config.get('load_in_4bit', False)}")
+ print(f" Gradient checkpointing: {model_config.get('gradient_checkpointing', False)}")
+ print(f" Device map: {model_config.get('device_map', 'None')}")
+
+ if model_config.get('load_in_4bit', False):
+ print("✓ Already using 4-bit quantization (good for memory)")
+ else:
+ print("✗ Consider enabling 4-bit quantization")
+
+ if not model_config.get('gradient_checkpointing', False):
+ print("✗ Consider enabling gradient checkpointing")
+
+
+def main():
+ """Main function"""
+ print("VRAM and Memory Analysis for Progressive LLM Training")
+ print("=" * 60)
+
+ # Get memory info
+ get_memory_info()
+
+ # Estimate model sizes
+ models = [
+ "google/gemma-2-2b-it",
+ "google/gemma-3-1b-pt",
+ "meta-llama/Llama-3.2-8B",
+ ]
+
+ for model in models:
+ estimate_model_size(model)
+
+ # Suggest strategies
+ suggest_offloading_strategies()
+
+ # Check configs
+ configs = [
+ "config/training_config_gemma3_1b.yaml",
+ "config/training_config_gemma2_small.yaml",
+ ]
+
+ for config in configs:
+ check_config_compatibility(config)
+
+ print("\n=== Recommendations ===")
+ print("1. Start with QLoRA (4-bit) if not already enabled")
+ print("2. Use device_map with max_memory limits")
+ print("3. Enable gradient checkpointing")
+ print("4. Consider DeepSpeed for advanced offloading")
+ print("5. Monitor actual usage during training")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/scripts/compare_models_tui.py b/scripts/compare_models_tui.py
new file mode 100755
index 0000000..b2913b1
--- /dev/null
+++ b/scripts/compare_models_tui.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python3
+"""
+TUI for comparing original and trained models
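+
+Usage:
+ python scripts/compare_models_tui.py --mode interactive
+ python scripts/compare_models_tui.py --mode benchmark [--config PATH]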
+"""
+
+import sys
+from pathlib import Path
+import yaml
+import torch
+from rich.console import Console
+from rich.panel import Panel
+from rich.columns import Columns
+from rich.prompt import Prompt
+from rich.text import Text
+from rich.layout import Layout
+from rich.live import Live
+from rich.table import Table
+import time
+
+# Add src to path
+sys.path.append(str(Path(__file__).parent.parent))
+
+from src.progressive_model import ProgressiveReasoningModel
+
+
+class ModelCompareTUI:
+ def __init__(self, config_path: str = "config/training_config.yaml"):
+ self.console = Console()
+
+ # Load configuration
+ with open(config_path) as f:
+ self.config = yaml.safe_load(f)
+
+ # Initialize models
+ self.console.print("[yellow]Loading models...[/yellow]")
+
+ # Original model
+ self.original_model = ProgressiveReasoningModel(self.config)
+ self.original_model.setup_base_model()
+
+ # Trained model
+ self.trained_model = ProgressiveReasoningModel(self.config)
+ self.trained_model.setup_base_model()
+
+ # Load the trained adapter if it exists
+ adapter_path = Path(self.config["experiment"]["output_dir"]) / "adapters" / "basic_cot"
+ if adapter_path.exists():
+ self.console.print(f"[green]Loading trained adapter from: {adapter_path}[/green]")
+ self.trained_model.load_for_inference(["basic_cot"])
+ else:
+ self.console.print("[red]No trained adapter found. Please run training first.[/red]")
+ self.console.print("[yellow]Both models will show original behavior.[/yellow]")
+
+ self.console.print("[green]Models loaded successfully![/green]\n")
+
+ def generate_response(self, model, prompt: str, with_think_tags: bool = True) -> str:
+ """Generate response from a model"""
+ # For trained model, encourage think tags
+ if with_think_tags and model == self.trained_model:
+ formatted_prompt = f"{prompt}\n\nPlease think step by step."
+ else:
+ formatted_prompt = prompt
+
+ inputs = model.tokenizer(formatted_prompt, return_tensors="pt").to(model.model.device)
+
+ with torch.no_grad():
+ outputs = model.model.generate(
+ **inputs,
+ max_length=512,
+ temperature=0.7,
+ do_sample=True,
+ top_p=0.95,
+ pad_token_id=model.tokenizer.pad_token_id,
+ eos_token_id=model.tokenizer.eos_token_id
+ )
+
+ response = model.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ # Extract response after prompt
+ response = response[len(formatted_prompt):].strip()
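+ # (assumes the decoded text begins with the exact prompt string; with
+ # skip_special_tokens=True this usually holds but is not guaranteed)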
+
+ return response
+
+ def create_comparison_panel(self, prompt: str, original_response: str, trained_response: str) -> Panel:
+ """Create a panel showing the comparison"""
+ # Create table
+ table = Table(show_header=True, header_style="bold magenta", expand=True)
+ table.add_column("Original Model", style="cyan", width=50)
+ table.add_column("Trained Model (with CoT)", style="green", width=50)
+
+ table.add_row(original_response, trained_response)
+
+ return Panel(
+ table,
+ title=f"[bold yellow]Prompt: {prompt}[/bold yellow]",
+ border_style="blue"
+ )
+
+ def run_interactive_mode(self):
+ """Run interactive comparison mode"""
+ self.console.print("\n[bold cyan]Model Comparison TUI[/bold cyan]")
+ self.console.print("Compare responses from original and trained models\n")
+ self.console.print("[dim]Type 'quit' or 'exit' to leave[/dim]\n")
+
+ while True:
+ # Get user prompt
+ prompt = Prompt.ask("\n[bold yellow]Enter your prompt[/bold yellow]")
+
+ if prompt.lower() in ['quit', 'exit']:
+ self.console.print("\n[yellow]Goodbye![/yellow]")
+ break
+
+ # Generate responses
+ self.console.print("\n[dim]Generating responses...[/dim]")
+
+ start_time = time.time()
+ original_response = self.generate_response(self.original_model, prompt, with_think_tags=False)
+ original_time = time.time() - start_time
+
+ start_time = time.time()
+ trained_response = self.generate_response(self.trained_model, prompt, with_think_tags=True)
+ trained_time = time.time() - start_time
+
+ # Display comparison
+ panel = self.create_comparison_panel(prompt, original_response, trained_response)
+ self.console.print(panel)
+
+ # Show generation times
+ self.console.print(f"\n[dim]Generation times - Original: {original_time:.2f}s, Trained: {trained_time:.2f}s[/dim]")
+
+ def run_benchmark_mode(self):
+ """Run benchmark with predefined prompts"""
+ test_prompts = [
+ "What is 156 + 389?",
+ "If I have 23 apples and buy 17 more, how many do I have?",
+ "A store has 145 items. If 38 are sold, how many remain?",
+ "What is 45 * 12?",
+ "Explain why 2 + 2 = 4",
+ "If a train travels 80 km/h for 2.5 hours, how far does it go?",
+ "What is the sum of all numbers from 1 to 10?",
+ "How many minutes are in 3.5 hours?",
+ ]
+
+ self.console.print("\n[bold cyan]Running Benchmark Comparison[/bold cyan]\n")
+
+ for i, prompt in enumerate(test_prompts, 1):
+ self.console.print(f"[bold]Test {i}/{len(test_prompts)}[/bold]")
+
+ # Generate responses
+ original_response = self.generate_response(self.original_model, prompt, with_think_tags=False)
+ trained_response = self.generate_response(self.trained_model, prompt, with_think_tags=True)
+
+ # Display comparison
+ panel = self.create_comparison_panel(prompt, original_response, trained_response)
+ self.console.print(panel)
+ self.console.print("")
+
+ self.console.print("[green]Benchmark completed![/green]")
+
+
+def main():
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Compare original and trained models")
+ parser.add_argument("--mode", choices=["interactive", "benchmark"], default="interactive",
+ help="Mode to run the comparison")
+ parser.add_argument("--config", default="config/training_config.yaml",
+ help="Path to configuration file")
+
+ args = parser.parse_args()
+
+ # Create TUI
+ tui = ModelCompareTUI(args.config)
+
+ # Run in selected mode
+ if args.mode == "interactive":
+ tui.run_interactive_mode()
+ else:
+ tui.run_benchmark_mode()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/scripts/evaluate.py b/scripts/evaluate.py
new file mode 100755
index 0000000..485ca76
--- /dev/null
+++ b/scripts/evaluate.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+"""
+Evaluation script for progressive model
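+
+Usage:
+ python scripts/evaluate.py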
+"""
+
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+
+from src.progressive_model import ProgressiveReasoningModel
+import yaml
+
+
+def evaluate_reasoning(model_wrapper, test_prompts):
+ """Evaluate model on test prompts"""
+ results = []
+
+ for prompt in test_prompts:
+ print(f"\nPrompt: {prompt}")
+ response = model_wrapper.generate_with_reasoning(prompt)
+ print(f"Response: {response}")
+ results.append({
+ "prompt": prompt,
+ "response": response
+ })
+
+ return results
+
+
+def main():
+ # Load config
+ with open("config/training_config.yaml") as f:
+ config = yaml.safe_load(f)
+
+ # Initialize model
+ model_wrapper = ProgressiveReasoningModel(config)
+ model_wrapper.setup_base_model()
+
+ # Test different adapters
+ test_prompts = [
+ "What is 156 + 389?",
+ "If a train travels 80 km/h for 2.5 hours, how far does it go?",
+ "Explain why the sky is blue.",
+ ]
+
+ # Test each adapter
+ for adapter_name in ["basic_cot", "math_reasoning", "complex_reasoning"]:
+ if adapter_name in model_wrapper.adapters:
+ print(f"\n{'='*50}")
+ print(f"Testing adapter: {adapter_name}")
+ print(f"{'='*50}")
+
+ model_wrapper.load_for_inference([adapter_name])
+ results = evaluate_reasoning(model_wrapper, test_prompts)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/scripts/simple_compare.py b/scripts/simple_compare.py
new file mode 100755
index 0000000..dcea3ff
--- /dev/null
+++ b/scripts/simple_compare.py
@@ -0,0 +1,189 @@
+#!/usr/bin/env python3
+"""
+Simple comparison script without rich TUI
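+
+Usage:
+ python scripts/simple_compare.py [--config PATH] [--adapter NAME] [--max-length N]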
+"""
+
+import sys
+from pathlib import Path
+import yaml
+import torch
+import argparse
+
+# Add src to path
+sys.path.append(str(Path(__file__).parent.parent))
+
+from src.progressive_model import ProgressiveReasoningModel
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Compare original and trained models")
+ parser.add_argument(
+ "--config", "-c",
+ type=str,
+ default="config/training_config_gemma2_small.yaml",
+ help="Path to configuration file"
+ )
+ parser.add_argument(
+ "--adapter", "-a",
+ type=str,
+ default="basic_cot",
+ help="Adapter name to load for comparison"
+ )
+ parser.add_argument(
+ "--max-length",
+ type=int,
+ default=512,
+ help="Maximum generation length"
+ )
+ return parser.parse_args()
+
+
+def load_config(config_path):
+ """Load configuration from file"""
+ config_path = Path(config_path)
+ if not config_path.exists():
+ raise FileNotFoundError(f"Configuration file not found: {config_path}")
+
+ with open(config_path) as f:
+ config = yaml.safe_load(f)
+ return config
+
+
+def generate_response(model, tokenizer, prompt, max_length=512):
+ """Generate response using the model"""
+ # Format prompt for Gemma
+ formatted_prompt = f"