Commit 2c30e06f20: initial

49 changed files with 3628 additions and 0 deletions

163  .devenv.flake.nix  Normal file
@@ -0,0 +1,163 @@
{
  inputs =
    let
      version = "1.6.1";
      system = "x86_64-linux";
      devenv_root = "/home/centra/dev/pnn/progressive-llm-training";
      devenv_dotfile = ./.devenv;
      devenv_dotfile_string = ".devenv";
      container_name = null;
      devenv_tmpdir = "/run/user/1000";
      devenv_runtime = "/run/user/1000/devenv-adeda32";
      devenv_istesting = false;
      devenv_direnvrc_latest_version = 1;

    in {
      git-hooks.url = "github:cachix/git-hooks.nix";
      git-hooks.inputs.nixpkgs.follows = "nixpkgs";
      pre-commit-hooks.follows = "git-hooks";
      nixpkgs.url = "github:cachix/devenv-nixpkgs/rolling";
      devenv.url = "github:cachix/devenv?dir=src/modules";
    } // (if builtins.pathExists (devenv_dotfile + "/flake.json")
      then builtins.fromJSON (builtins.readFile (devenv_dotfile + "/flake.json"))
      else { });

  outputs = { nixpkgs, ... }@inputs:
    let
      version = "1.6.1";
      system = "x86_64-linux";
      devenv_root = "/home/centra/dev/pnn/progressive-llm-training";
      devenv_dotfile = ./.devenv;
      devenv_dotfile_string = ".devenv";
      container_name = null;
      devenv_tmpdir = "/run/user/1000";
      devenv_runtime = "/run/user/1000/devenv-adeda32";
      devenv_istesting = false;
      devenv_direnvrc_latest_version = 1;

      devenv =
        if builtins.pathExists (devenv_dotfile + "/devenv.json")
        then builtins.fromJSON (builtins.readFile (devenv_dotfile + "/devenv.json"))
        else { };
      getOverlays = inputName: inputAttrs:
        map
          (overlay:
            let
              input = inputs.${inputName} or (throw "No such input `${inputName}` while trying to configure overlays.");
            in
            input.overlays.${overlay} or (throw "Input `${inputName}` has no overlay called `${overlay}`. Supported overlays: ${nixpkgs.lib.concatStringsSep ", " (builtins.attrNames input.overlays)}"))
          inputAttrs.overlays or [ ];
      overlays = nixpkgs.lib.flatten (nixpkgs.lib.mapAttrsToList getOverlays (devenv.inputs or { }));
      pkgs = import nixpkgs {
        inherit system;
        config = {
          allowUnfree = devenv.allowUnfree or false;
          allowBroken = devenv.allowBroken or false;
          permittedInsecurePackages = devenv.permittedInsecurePackages or [ ];
        };
        inherit overlays;
      };
      lib = pkgs.lib;
      importModule = path:
        if lib.hasPrefix "./" path
        then if lib.hasSuffix ".nix" path
          then ./. + (builtins.substring 1 255 path)
          else ./. + (builtins.substring 1 255 path) + "/devenv.nix"
        else if lib.hasPrefix "../" path
        then throw "devenv: ../ is not supported for imports"
        else
          let
            paths = lib.splitString "/" path;
            name = builtins.head paths;
            input = inputs.${name} or (throw "Unknown input ${name}");
            subpath = "/${lib.concatStringsSep "/" (builtins.tail paths)}";
            devenvpath = "${input}" + subpath;
            devenvdefaultpath = devenvpath + "/devenv.nix";
          in
          if lib.hasSuffix ".nix" devenvpath
          then devenvpath
          else if builtins.pathExists devenvdefaultpath
          then devenvdefaultpath
          else throw (devenvdefaultpath + " file does not exist for input ${name}.");
      project = pkgs.lib.evalModules {
        specialArgs = inputs // { inherit inputs; };
        modules = [
          ({ config, ... }: {
            _module.args.pkgs = pkgs.appendOverlays (config.overlays or [ ]);
          })
          (inputs.devenv.modules + /top-level.nix)
          {
            devenv.cliVersion = version;
            devenv.root = devenv_root;
            devenv.dotfile = devenv_root + "/" + devenv_dotfile_string;
          }
          (pkgs.lib.optionalAttrs (inputs.devenv.isTmpDir or false) {
            devenv.tmpdir = devenv_tmpdir;
            devenv.runtime = devenv_runtime;
          })
          (pkgs.lib.optionalAttrs (inputs.devenv.hasIsTesting or false) {
            devenv.isTesting = devenv_istesting;
          })
          (pkgs.lib.optionalAttrs (container_name != null) {
            container.isBuilding = pkgs.lib.mkForce true;
            containers.${container_name}.isBuilding = true;
          })
          ({ options, ... }: {
            config.devenv = pkgs.lib.optionalAttrs (builtins.hasAttr "direnvrcLatestVersion" options.devenv) {
              direnvrcLatestVersion = devenv_direnvrc_latest_version;
            };
          })
        ] ++ (map importModule (devenv.imports or [ ])) ++ [
          (if builtins.pathExists ./devenv.nix then ./devenv.nix else { })
          (devenv.devenv or { })
          (if builtins.pathExists ./devenv.local.nix then ./devenv.local.nix else { })
          (if builtins.pathExists (devenv_dotfile + "/cli-options.nix") then import (devenv_dotfile + "/cli-options.nix") else { })
        ];
      };
      config = project.config;

      options = pkgs.nixosOptionsDoc {
        options = builtins.removeAttrs project.options [ "_module" ];
        warningsAreErrors = false;
        # Unpack Nix types, e.g. literalExpression, mDoc.
        transformOptions =
          let isDocType = v: builtins.elem v [ "literalDocBook" "literalExpression" "literalMD" "mdDoc" ];
          in lib.attrsets.mapAttrs (_: v:
            if v ? _type && isDocType v._type then
              v.text
            else if v ? _type && v._type == "derivation" then
              v.name
            else
              v
          );
      };

      build = options: config:
        lib.concatMapAttrs
          (name: option:
            if builtins.hasAttr "type" option then
              if option.type.name == "output" || option.type.name == "outputOf" then {
                ${name} = config.${name};
              } else { }
            else
              let v = build option config.${name};
              in if v != { } then {
                ${name} = v;
              } else { }
          )
          options;

      systems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ];
    in
    {
      devShell = lib.genAttrs systems (system: config.shell);
      packages = lib.genAttrs systems (system: {
        optionsJSON = options.optionsJSON;
        # deprecated
        inherit (config) info procfileScript procfileEnv procfile;
        ci = config.ciDerivation;
      });
      devenv = config;
      build = build project.options project.config;
    };
}

1  .devenv/bash  Symbolic link
@@ -0,0 +1 @@
/nix/store/94lg0shvsfc845zy8gnflvpqxxiyijbz-bash-interactive-5.2p37

1  .devenv/devenv.json  Normal file
@@ -0,0 +1 @@
{"inputs":{"nixpkgs":{"url":"github:NixOS/nixpkgs/nixos-unstable"},"nixpkgs-python":{"url":"github:cachix/nixpkgs-python","inputs":{"nixpkgs":{"follows":"nixpkgs"}}}},"allowUnfree":true}

1  .devenv/flake.json  Normal file
@@ -0,0 +1 @@
{"nixpkgs":{"url":"github:NixOS/nixpkgs/nixos-unstable"},"nixpkgs-python":{"url":"github:cachix/nixpkgs-python","inputs":{"nixpkgs":{"follows":"nixpkgs"}}}}

1  .devenv/gc/shell  Symbolic link
@@ -0,0 +1 @@
shell-1-link

1  .devenv/gc/shell-1-link  Symbolic link
@@ -0,0 +1 @@
/nix/store/7fimdw1in7f1g0wxw5cr9pg26rs4rp5g-devenv-shell-env

0  .devenv/imports.txt  Normal file

11  .devenv/input-paths.txt  Normal file
@@ -0,0 +1,11 @@
/home/centra/.config/nixpkgs/config.nix
/home/centra/.config/nixpkgs/overlays
/home/centra/.config/nixpkgs/overlays.nix
/home/centra/.nixpkgs/config.nix
/home/centra/dev/pnn/progressive-llm-training/.devenv/flake.json
/home/centra/dev/pnn/progressive-llm-training/.devenv.flake.nix
/home/centra/dev/pnn/progressive-llm-training/.env
/home/centra/dev/pnn/progressive-llm-training/devenv.local.nix
/home/centra/dev/pnn/progressive-llm-training/devenv.lock
/home/centra/dev/pnn/progressive-llm-training/devenv.nix
/home/centra/dev/pnn/progressive-llm-training/devenv.yaml

3  .devenv/load-exports  Executable file
@@ -0,0 +1,3 @@
export PATH='/home/centra/dev/pnn/progressive-llm-training/.devenv/state/venv/bin:/nix/store/bdqwd2frn9m7n3hj2436s0vlnv7mawpc-python3-3.11.13-env/bin:/nix/store/9w80x8njl1hcp8vlk1f3x17q4hcd2cqp-evaluate/bin:/nix/store/8df6wqahd2fqzl04kcs3xs32yqqimcxb-install-packages/bin:/nix/store/v5rz1h6ci23icfp6y228r2m0fqrdf408-install-packages-cpu/bin:/nix/store/69142b4sjmb4jffmyjb8nar6qzlgxnpg-prepare-data/bin:/nix/store/bhb6l6yfqknnwc7y5j5xc9k866hajv7b-train/bin:/nix/store/pbqah1qk4b5y14fqinr1h8zvhqy71v81-gcc-wrapper-14.3.0/bin:/nix/store/sa7j7cddyblhcb3ch3ds10w7nw75yjj1-gcc-14.3.0/bin:/nix/store/mdmsnfcvxyk5ynz7nx8nhss1wig0gljx-glibc-2.40-66-bin/bin:/nix/store/psy9v2asypgl9ylg8cnzkixc7fv0snj0-coreutils-9.7/bin:/nix/store/cadx5p7c0i06gf6h84iw9mrhx56imbv0-binutils-wrapper-2.44/bin:/nix/store/z3za8hfc24wb117s50p8b10agjkgm039-binutils-2.44/bin:/nix/store/dx4bdrs7mq3jfviqhszrc7l35ps9kg64-cmake-3.31.7/bin:/nix/store/1492q00cm64n0hs5966s8cqj6j0x5nxg-ninja-1.12.1/bin:/nix/store/h5khrpnjj3fb182sc32fx1z75w0lhksy-pkg-config-wrapper-0.29.2/bin:/nix/store/rzqvhv48m3nh8g3j4k6jmz6yqy8apr95-git-2.49.0/bin:/nix/store/nygfbkv0j6fvwwa82mdwxm4qfiq3p4q2-git-lfs-3.6.1/bin:/nix/store/fir4g1m8dvg46mh8silh3wnmm9mc0jix-htop-3.4.1/bin:/nix/store/9mc2m4sacbk4l7sc4w7m08m1x9bf5dgn-tmux-3.5a/bin:/nix/store/cxy72qdk41k3zjs5fw1nw1whv6wf7hv2-vim-9.1.1401/bin:/nix/store/74k8qwbfa6lm8psm2vjh2vj04fpr6c5g-openssl-3.4.1-bin/bin:/nix/store/m9k83ip1yx29xs94sa5x8j70s2vfgj6i-glib-2.84.2-dev/bin:/nix/store/zs5crhr67zp8cxn7dh4mwq08zw3sb31m-gettext-0.22.5/bin:/nix/store/rklrz4rwi03hxvz0kwh75vz55wb9b1qz-glib-2.84.2-bin/bin:/nix/store/xbpwk3xzanxj12157byj6wjagm2wfb3c-cuda-merged-12.8/bin:/nix/store/v0zrnzl3anb71ma5c2kx71dl8kyh0rf6-cuda_cuobjdump-12.8.90-bin/bin:/nix/store/v4mm21f67qki6ss6mqp3anlmaiw0r1zd-pre-commit-bin/bin:/nix/store/mq2i9br9h890bnahlds9jnff1jf6xjpb-python3.13-black-25.1.0/bin:/nix/store/sd81bvmch7njdpwx3lkjslixcbj5mivz-python3-3.13.4/bin:/nix/store/mdzm1l0rnpwp8ha0mbxll0db4r2p0xj3-python3.13-flake8-7.2.0/bin:/nix/store/xs72vlx7i6snrrrqx2zn529fbbqrwlwq-python3.13-pycodestyle-2.13.0/bin:/nix/store/5a8m3p0svp6myq1cz4ww431fsbh3xrg5-python3.13-pyflakes-3.3.2/bin:/nix/store/p6bch581drrxv3dm7vwxqazpbssjz4nv-python3.13-mypy-1.15.0/bin:/nix/store/1c8sm86wj45vwkb3ww2b870h9i9wna6r-patchelf-0.15.0/bin:/nix/store/psy9v2asypgl9ylg8cnzkixc7fv0snj0-coreutils-9.7/bin:/nix/store/c14zwgl8hf1wm0izij2i16xvk8ak70cy-findutils-4.10.0/bin:/nix/store/ibx4jfwlhjg4g0s6rrxrpaxa3ka8ns4m-diffutils-3.12/bin:/nix/store/pr318zsl44jdwpk9wk0sdrn19b6in7ah-gnused-4.9/bin:/nix/store/bc6zxzjnkjp4r9nhz5imy3cypvdh6r4n-gnugrep-3.12/bin:/nix/store/nv3y7zb1cwz1h9qy7nwz0s54j8dl1kqj-gawk-5.3.2/bin:/nix/store/lp82dcnrzljyix6yigwzrlpr1smvpmb0-gnutar-1.35/bin:/nix/store/6ag5dhk7sma61p6vl0kazfmpbrq08nqh-gzip-1.14/bin:/nix/store/ykdv4id6893gmkqwdmbimq237c1xqvq7-bzip2-1.0.8-bin/bin:/nix/store/6bwp1y45zlyvpr4ja2sk1yi9v5mrs94x-gnumake-4.4.1/bin:/nix/store/00zrahbb32nzawrmv9sjxn36h7qk9vrs-bash-5.2p37/bin:/nix/store/c9xmgszbf6i4dfq9r953khk9d7fdqigw-patch-2.8/bin:/nix/store/ikfwx7kbwz9zr7fziiac7f57jgbh3bnv-xz-5.8.1-bin/bin:/nix/store/3pdmbqy86wsbjdazxv1n3vrmj60vn0ri-file-5.45/bin:/run/wrappers/bin:/home/centra/.local/share/flatpak/exports/bin:/var/lib/flatpak/exports/bin:/home/centra/.nix-profile/bin:/nix/profile/bin:/home/centra/.local/state/nix/profile/bin:/etc/profiles/per-user/centra/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin'
export VIRTUAL_ENV=/home/centra/dev/pnn/progressive-llm-training/.devenv/state/venv

BIN  .devenv/nix-eval-cache.db  Normal file
Binary file not shown.
BIN  .devenv/nix-eval-cache.db-shm  Normal file
Binary file not shown.
BIN  .devenv/nix-eval-cache.db-wal  Normal file
Binary file not shown.

1  .devenv/profile  Symbolic link
@@ -0,0 +1 @@
/nix/store/y2vscmx3lckyzyag6xg8b02pkdsk326d-devenv-profile

1  .devenv/run  Symbolic link
@@ -0,0 +1 @@
/run/user/1000/devenv-adeda32

1  .devenv/state/git-hooks/config.json  Normal file
@@ -0,0 +1 @@
{configPath:.pre-commit-config.yaml}

BIN  .devenv/tasks.db  Normal file
Binary file not shown.
BIN  .devenv/tasks.db-shm  Normal file
Binary file not shown.
BIN  .devenv/tasks.db-wal  Normal file
Binary file not shown.

32  .gitignore  vendored  Normal file
@@ -0,0 +1,32 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
venv/
ENV/
env/
.venv/

# Nix
result
result-*

# Project specific
outputs/
data/
*.log
wandb/
.ipynb_checkpoints/
*.pt
*.pth
*.bin
*.safetensors

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

33  =2.5.0  Normal file
@@ -0,0 +1,33 @@
Collecting flash-attn
  Using cached flash_attn-2.8.0.post2-cp311-cp311-linux_x86_64.whl
Requirement already satisfied: torch in ./.devenv/state/venv/lib/python3.11/site-packages (from flash-attn) (2.7.1+cu128)
Collecting einops (from flash-attn)
  Using cached einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Requirement already satisfied: filelock in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (3.13.1)
Requirement already satisfied: typing-extensions>=4.10.0 in /nix/store/x74hdbjsz4ck98w8lyxv8kkwxs1wm2il-python3.13-typing-extensions-4.13.2/lib/python3.13/site-packages (from torch->flash-attn) (4.13.2)
Requirement already satisfied: sympy>=1.13.3 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (1.13.3)
Requirement already satisfied: networkx in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (3.3)
Requirement already satisfied: jinja2 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (3.1.4)
Requirement already satisfied: fsspec in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (2024.6.1)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.61 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.61)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.57 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.57)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.57 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.57)
Requirement already satisfied: nvidia-cudnn-cu12==9.7.1.26 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (9.7.1.26)
Requirement already satisfied: nvidia-cublas-cu12==12.8.3.14 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.3.14)
Requirement already satisfied: nvidia-cufft-cu12==11.3.3.41 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (11.3.3.41)
Requirement already satisfied: nvidia-curand-cu12==10.3.9.55 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (10.3.9.55)
Requirement already satisfied: nvidia-cusolver-cu12==11.7.2.55 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (11.7.2.55)
Requirement already satisfied: nvidia-cusparse-cu12==12.5.7.53 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.5.7.53)
Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (0.6.3)
Requirement already satisfied: nvidia-nccl-cu12==2.26.2 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (2.26.2)
Requirement already satisfied: nvidia-nvtx-cu12==12.8.55 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.55)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.61 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.61)
Requirement already satisfied: nvidia-cufile-cu12==1.13.0.11 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (1.13.0.11)
Requirement already satisfied: triton==3.3.1 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (3.3.1)
Requirement already satisfied: setuptools>=40.8.0 in ./.devenv/state/venv/lib/python3.11/site-packages (from triton==3.3.1->torch->flash-attn) (80.9.0)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in ./.devenv/state/venv/lib/python3.11/site-packages (from sympy>=1.13.3->torch->flash-attn) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in ./.devenv/state/venv/lib/python3.11/site-packages (from jinja2->torch->flash-attn) (2.1.5)
Using cached einops-0.8.1-py3-none-any.whl (64 kB)
Installing collected packages: einops, flash-attn

Successfully installed einops-0.8.1 flash-attn-2.8.0.post2

124  LORA_TARGET_MODULES.md  Normal file
@@ -0,0 +1,124 @@
# LoRA Target Modules Reference

This document provides the correct target module names for different model architectures when using LoRA (Low-Rank Adaptation).

## Model Architecture Detection

Use the inspection script to find correct target modules:

```bash
# In the nix development environment
python /home/centra/dev/pnn/inspect_conv1d_model.py [model_name]
```
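
If the inspection script is not at hand, the same information can be pulled directly from `transformers`. A minimal sketch (assuming `torch` and `transformers` are installed; the model name is only an example):

```python
# Sketch: list candidate LoRA target modules and their layer types.
import torch
from transformers import AutoModelForCausalLM
from transformers.pytorch_utils import Conv1D  # GPT-2 style layers are Conv1D, not Linear

model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")

for name, module in model.named_modules():
    if isinstance(module, (torch.nn.Linear, Conv1D)):
        # The trailing component of `name` (e.g. "c_attn", "q_proj")
        # is what goes into target_modules.
        print(f"{name}: {type(module).__name__}")
```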

## Common Model Architectures

### GPT-2 / DialoGPT Models
- **Model Type**: GPT2LMHeadModel
- **Layer Type**: Conv1D (not Linear!)
- **Base Model**: microsoft/DialoGPT-small, gpt2, gpt2-medium, gpt2-large, gpt2-xl

#### Attention Modules
- `c_attn` - Combined query, key, value projection (nf=3*hidden_size)
- `c_proj` - Output projection

#### MLP Modules
- `mlp.c_fc` - Feed-forward up projection
- `mlp.c_proj` - Feed-forward down projection

#### Recommended Configurations
```yaml
# Basic stage (attention only)
target_modules: ["c_attn", "c_proj"]

# Advanced stage (attention + MLP)
target_modules: ["c_attn", "c_proj", "mlp.c_fc", "mlp.c_proj"]
```

### LLaMA Models
- **Model Type**: LlamaForCausalLM
- **Layer Type**: Linear
- **Base Model**: meta-llama/Llama-2-7b-hf, meta-llama/Llama-3.2-8B

#### Attention Modules
- `q_proj` - Query projection
- `k_proj` - Key projection
- `v_proj` - Value projection
- `o_proj` - Output projection

#### MLP Modules
- `gate_proj` - Gate projection
- `up_proj` - Up projection
- `down_proj` - Down projection

#### Recommended Configurations
```yaml
# Basic stage (attention only)
target_modules: ["q_proj", "v_proj"]

# Advanced stage (attention + MLP)
target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
```

### Mistral Models
- **Model Type**: MistralForCausalLM
- **Layer Type**: Linear
- **Base Model**: mistralai/Mistral-7B-v0.1

#### Target Modules (same as LLaMA)
```yaml
target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
```

### Qwen Models
- **Model Type**: QWenLMHeadModel
- **Layer Type**: Linear
- **Base Model**: Qwen/Qwen-7B

#### Target Modules
```yaml
target_modules: ["c_attn", "c_proj", "w1", "w2"]
```

## Important Notes

1. **Conv1D vs Linear**: GPT-2 based models use `Conv1D` layers, not `Linear` layers
2. **Module Patterns**: Use simple patterns like `"c_attn"` rather than full paths like `"transformer.h.0.attn.c_attn"`
3. **Testing**: Always test your configuration before training by creating a PEFT model
4. **Architecture Variations**: Different model families use different naming conventions

## Troubleshooting

### Error: "Target module not found"
- Run the inspection script to find actual module names
- Check if the model uses Conv1D or Linear layers
- Verify the module naming pattern for your specific model

### Error: "No trainable parameters"
- Ensure target modules exist in the model
- Check that the module names match exactly
- Verify the model architecture is supported by PEFT

## Testing Your Configuration

```python
from peft import get_peft_model, LoraConfig, TaskType
from transformers import AutoModelForCausalLM

# Load the base model you want to adapt (example model; replace with yours)
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")

# Test configuration
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["c_attn", "c_proj"],  # Your target modules
    bias="none"
)

# Try to create PEFT model
try:
    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()
    print("✓ Configuration works!")
except Exception as e:
    print(f"✗ Configuration failed: {e}")
```

85  config/README.md  Normal file
@@ -0,0 +1,85 @@
# Training Configuration Files

This directory contains configuration files for different model sizes and use cases.

## Available Configurations

### Small Models (Testing)
- `training_config.yaml` - Default configuration for small models (DialoGPT-small)
  - Memory: ~1GB VRAM
  - Batch size: 8
  - No quantization

### Medium Models (8B)
- `training_config_large.yaml` - Configuration for 8B models (Llama-3.2-8B)
  - Memory: ~12GB VRAM with 4-bit quantization
  - Batch size: 1, gradient accumulation: 16-64
  - 4-bit quantization enabled

### Large Models (13B)
- `training_config_13b.yaml` - Configuration for 13B models
  - Memory: ~16GB VRAM with 4-bit quantization
  - Batch size: 1, gradient accumulation: 32-128
  - Higher LoRA ranks (32-128)

### Extra Large Models (70B)
- `training_config_70b.yaml` - Configuration for 70B models
  - Memory: ~40GB+ VRAM with 4-bit quantization
  - Batch size: 1, gradient accumulation: 64-256
  - Maximum LoRA ranks (64-256)
  - Multi-GPU support with FSDP

## Configuration Parameters

### Model Settings
- `load_in_4bit`: Enable 4-bit quantization (recommended for large models)
- `gradient_checkpointing`: Trade compute for memory
- `use_flash_attention_2`: Faster attention computation if available
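
As a rough sketch, these model settings typically map onto a quantized `transformers` load as shown below. This is illustrative only, not necessarily the exact loader used by the training scripts; the model name and values mirror `training_config_large.yaml`:

```python
# Sketch: 4-bit load matching load_in_4bit / bnb_4bit_* / gradient_checkpointing.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-8B",         # example base_model from the 8B config
    quantization_config=bnb_config,
    device_map="auto",
)
model.gradient_checkpointing_enable()  # corresponds to gradient_checkpointing: true
```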

### Adapter Settings
- `r`: LoRA rank (higher = more parameters but better capacity)
- `lora_alpha`: LoRA scaling factor (typically 2x the rank)
- `init_lora_weights`: Set to `true` for identity initialization

### Training Settings
- `per_device_batch_size`: Usually 1 for large models
- `gradient_accumulation_steps`: Effective batch size multiplier
- `learning_rate`: Lower for larger models
- `bf16`: Use bfloat16 for better numerical stability
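
The effective batch size is the product of these settings across devices; a quick sanity check (numbers taken from the 8B config, single GPU assumed):

```python
# Effective batch size = per_device_batch_size * gradient_accumulation_steps * num_devices
per_device_batch_size = 1
gradient_accumulation_steps = 16
num_devices = 1
print(per_device_batch_size * gradient_accumulation_steps * num_devices)  # 16
```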

## Usage

```bash
# For 8B models
python scripts/train_progressive.py --config config/training_config_large.yaml

# For 13B models
python scripts/train_progressive.py --config config/training_config_13b.yaml

# For 70B models (requires multiple GPUs)
python scripts/train_progressive.py --config config/training_config_70b.yaml
```

## Memory Requirements

| Model Size | VRAM (4-bit) | VRAM (16-bit) | GPUs Recommended |
|------------|--------------|---------------|------------------|
| 8B         | 12-16GB      | 32GB          | 1x RTX 4090      |
| 13B        | 16-20GB      | 52GB          | 1x A100          |
| 70B        | 40-60GB      | 140GB         | 2x A100          |

## Tips for Large Models

1. **Start with smaller models** to validate your approach
2. **Use gradient checkpointing** to reduce memory usage
3. **Monitor GPU memory** during training
4. **Use lower learning rates** for stability
5. **Consider multi-GPU setup** for 70B+ models
6. **Enable flash attention** if available for speed

## Troubleshooting

- **OOM errors**: Reduce batch size or enable gradient checkpointing
- **Slow training**: Enable flash attention, use bf16
- **Poor convergence**: Adjust learning rate or warmup steps
- **Multi-GPU issues**: Check FSDP configuration

36  config/training_config.yaml  Normal file
@@ -0,0 +1,36 @@
experiment:
  name: "progressive_reasoning_experiment"
  base_model: "microsoft/DialoGPT-small"  # Lightweight model for testing
  output_dir: "./outputs"
  use_wandb: false
  wandb_project: "matsuo-llm-comp-2025"

model:
  load_in_4bit: false  # Disable quantization for small model
  bnb_4bit_compute_dtype: "bfloat16"
  bnb_4bit_use_double_quant: true
  device_map: "auto"

progressive_stages:
  - name: "basic_cot"
    description: "Basic Chain-of-Thought reasoning"
    dataset_path: "./data/basic_cot/"
    adapter_config:
      r: 8
      lora_alpha: 16
      lora_dropout: 0.1
      target_modules: ["c_attn", "c_proj"]
    training:
      num_epochs: 2
      per_device_batch_size: 8  # Increase batch size for small model
      gradient_accumulation_steps: 2  # Reduce accumulation steps
      learning_rate: 5e-4  # Higher learning rate for faster training
      warmup_steps: 50
      max_length: 1024  # Shorter sequences

evaluation:
  benchmarks:
    - "HLE"  # Humanity's Last Exam
    - "Do-Not-Answer"
  save_results: true
  results_dir: "./outputs/evaluation_results"

83  config/training_config_13b.yaml  Normal file
@@ -0,0 +1,83 @@
experiment:
  name: "progressive_reasoning_13b"
  base_model: "meta-llama/Llama-3.2-13B"  # 13B model
  output_dir: "./outputs"
  use_wandb: true
  wandb_project: "matsuo-llm-comp-2025"

model:
  load_in_4bit: true
  bnb_4bit_compute_dtype: "bfloat16"
  bnb_4bit_use_double_quant: true
  bnb_4bit_quant_type: "nf4"
  device_map: "auto"
  gradient_checkpointing: true
  use_flash_attention_2: true

progressive_stages:
  - name: "basic_cot"
    description: "Basic Chain-of-Thought reasoning"
    dataset_path: "./data/basic_cot/"
    adapter_config:
      r: 32  # Higher rank for 13B models
      lora_alpha: 64
      lora_dropout: 0.05
      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"]
      init_lora_weights: true
    training:
      num_epochs: 1
      per_device_batch_size: 1
      gradient_accumulation_steps: 32
      learning_rate: 1e-4
      warmup_steps: 100
      max_length: 2048
      bf16: true
      max_grad_norm: 0.3
      weight_decay: 0.001

  - name: "math_reasoning"
    description: "Mathematical reasoning with think tags"
    dataset_path: "./data/math_reasoning/"
    inherit_from: "basic_cot"
    adapter_config:
      r: 64
      lora_alpha: 128
      lora_dropout: 0.05
      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 2
      per_device_batch_size: 1
      gradient_accumulation_steps: 64
      learning_rate: 8e-5
      warmup_steps: 200
      max_length: 4096
      bf16: true
      max_grad_norm: 0.3

  - name: "complex_reasoning"
    description: "Complex multi-step reasoning"
    dataset_path: "./data/complex_reasoning/"
    inherit_from: "math_reasoning"
    adapter_config:
      r: 128  # Maximum rank for 13B models
      lora_alpha: 256
      lora_dropout: 0.05
      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 1
      per_device_batch_size: 1
      gradient_accumulation_steps: 128
      learning_rate: 5e-5
      warmup_steps: 300
      max_length: 8192
      bf16: true
      max_grad_norm: 0.3

evaluation:
  benchmarks:
    - "HLE"
    - "Do-Not-Answer"
  save_results: true
  results_dir: "./outputs/evaluation_results"

101  config/training_config_70b.yaml  Normal file
@@ -0,0 +1,101 @@
experiment:
  name: "progressive_reasoning_70b"
  base_model: "meta-llama/Llama-3.2-70B"  # 70B model - requires significant resources
  output_dir: "./outputs"
  use_wandb: true
  wandb_project: "matsuo-llm-comp-2025"

model:
  load_in_4bit: true
  bnb_4bit_compute_dtype: "bfloat16"
  bnb_4bit_use_double_quant: true
  bnb_4bit_quant_type: "nf4"
  device_map: "auto"
  gradient_checkpointing: true
  use_flash_attention_2: true

progressive_stages:
  - name: "basic_cot"
    description: "Basic Chain-of-Thought reasoning"
    dataset_path: "./data/basic_cot/"
    adapter_config:
      r: 64  # Even higher rank for 70B models
      lora_alpha: 128
      lora_dropout: 0.05
      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"]
      init_lora_weights: true
    training:
      num_epochs: 1
      per_device_batch_size: 1
      gradient_accumulation_steps: 64
      learning_rate: 5e-5  # Lower learning rate for stability
      warmup_steps: 200
      max_length: 2048
      bf16: true
      max_grad_norm: 0.3
      weight_decay: 0.001
      dataloader_num_workers: 2

  - name: "math_reasoning"
    description: "Mathematical reasoning with think tags"
    dataset_path: "./data/math_reasoning/"
    inherit_from: "basic_cot"
    adapter_config:
      r: 128
      lora_alpha: 256
      lora_dropout: 0.05
      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 1
      per_device_batch_size: 1
      gradient_accumulation_steps: 128
      learning_rate: 3e-5
      warmup_steps: 300
      max_length: 4096
      bf16: true
      max_grad_norm: 0.3
      dataloader_num_workers: 2

  - name: "complex_reasoning"
    description: "Complex multi-step reasoning"
    dataset_path: "./data/complex_reasoning/"
    inherit_from: "math_reasoning"
    adapter_config:
      r: 256  # Maximum rank for 70B models
      lora_alpha: 512
      lora_dropout: 0.05
      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 1
      per_device_batch_size: 1
      gradient_accumulation_steps: 256
      learning_rate: 2e-5
      warmup_steps: 500
      max_length: 8192
      bf16: true
      max_grad_norm: 0.3
      dataloader_num_workers: 2

evaluation:
  benchmarks:
    - "HLE"
    - "Do-Not-Answer"
  save_results: true
  results_dir: "./outputs/evaluation_results"

# Additional settings for 70B models
optimization:
  gradient_checkpointing: true
  gradient_checkpointing_kwargs:
    use_reentrant: false
  ddp_find_unused_parameters: false
  # Multi-GPU settings
  fsdp: "full_shard auto_wrap"
  fsdp_transformer_layer_cls_to_wrap: "LlamaDecoderLayer"
  fsdp_min_num_params: 1000000
  fsdp_config:
    min_num_params: 1000000
    sharding_strategy: "FULL_SHARD"
    cpu_offload: false

config/training_config_gemma2_small.yaml
Normal file
91
config/training_config_gemma2_small.yaml
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
experiment:
|
||||
name: "progressive_reasoning_gemma2_small"
|
||||
base_model: "google/gemma-2-2b-it" # Instruction-tuned version
|
||||
output_dir: "./outputs"
|
||||
use_wandb: true
|
||||
wandb_project: "matsuo-llm-comp-2025"
|
||||
|
||||
model:
|
||||
load_in_4bit: false # 2B model is manageable without quantization
|
||||
bnb_4bit_compute_dtype: "bfloat16"
|
||||
bnb_4bit_use_double_quant: true
|
||||
device_map: "auto"
|
||||
gradient_checkpointing: false
|
||||
use_flash_attention_2: false
|
||||
use_eager_attention: true # Required for Gemma 3 models
|
||||
|
||||
progressive_stages:
|
||||
- name: "basic_cot"
|
||||
description: "Basic Chain-of-Thought reasoning"
|
||||
dataset_path: "./data/basic_cot/"
|
||||
adapter_config:
|
||||
r: 8 # Start with smaller rank for small model
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.1
|
||||
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||
init_lora_weights: true
|
||||
training:
|
||||
num_epochs: 3
|
||||
per_device_batch_size: 8 # Larger batch size for small model
|
||||
gradient_accumulation_steps: 2
|
||||
learning_rate: 5e-4 # Higher learning rate for small model
|
||||
warmup_steps: 50
|
||||
max_length: 1024
|
||||
bf16: true
|
||||
max_grad_norm: 1.0
|
||||
weight_decay: 0.001
|
||||
save_steps: 50
|
||||
logging_steps: 10
|
||||
|
||||
- name: "math_reasoning"
|
||||
description: "Mathematical reasoning with think tags"
|
||||
dataset_path: "./data/math_reasoning/"
|
||||
inherit_from: "basic_cot"
|
||||
adapter_config:
|
||||
r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.1
|
||||
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
|
||||
init_lora_weights: true
|
||||
training:
|
||||
num_epochs: 3
|
||||
per_device_batch_size: 4
|
||||
gradient_accumulation_steps: 4
|
||||
learning_rate: 3e-4
|
||||
warmup_steps: 100
|
||||
max_length: 2048
|
||||
bf16: true
|
||||
max_grad_norm: 1.0
|
||||
|
||||
- name: "complex_reasoning"
|
||||
description: "Complex multi-step reasoning with Mixture-of-Thoughts"
|
||||
dataset_path: "open-r1/Mixture-of-Thoughts" # HuggingFace dataset
|
||||
inherit_from: "math_reasoning"
|
||||
adapter_config:
|
||||
r: 32
|
||||
lora_alpha: 64
|
||||
lora_dropout: 0.1
|
||||
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
|
||||
init_lora_weights: true
|
||||
training:
|
||||
num_epochs: 1 # Large dataset, fewer epochs
|
||||
per_device_batch_size: 2
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 2e-4
|
||||
warmup_steps: 200
|
||||
max_length: 4096
|
||||
bf16: true
|
||||
max_grad_norm: 1.0
|
||||
save_steps: 500
|
||||
logging_steps: 50
|
||||
dataset_config:
|
||||
streaming: true
|
||||
max_samples: 30000
|
||||
split: "train"
|
||||
|
||||
evaluation:
|
||||
benchmarks:
|
||||
- "HLE"
|
||||
- "Do-Not-Answer"
|
||||
save_results: true
|
||||
results_dir: "./outputs/evaluation_results"
|
||||
102  config/training_config_gemma3_1b.yaml  Normal file
@@ -0,0 +1,102 @@
experiment:
  name: "progressive_reasoning_gemma3_1b"
  base_model: "google/gemma-3-1b-pt"  # Using Gemma 2 2B (1B might not be available)
  output_dir: "./outputs"
  use_wandb: true
  wandb_project: "matsuo-llm-comp-2025"

model:
  load_in_4bit: false
  bnb_4bit_compute_dtype: "bfloat16"
  bnb_4bit_use_double_quant: true
  device_map: "auto"
  gradient_checkpointing: false  # Not needed for small models
  use_flash_attention_2: false
  use_eager_attention: true

progressive_stages:
  - name: "basic_cot"
    description: "Basic Chain-of-Thought reasoning"
    dataset_path: "./data/basic_cot/"
    adapter_config:
      r: 8
      lora_alpha: 16
      lora_dropout: 0.1
      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]  # Gemma attention modules
      init_lora_weights: true
    training:
      num_epochs: 2
      per_device_batch_size: 8
      gradient_accumulation_steps: 2
      learning_rate: 5e-4
      warmup_steps: 50
      max_length: 1024
      fp16: false
      bf16: true
      max_grad_norm: 1.0
      weight_decay: 0.001
      save_steps: 100
      logging_steps: 10

  - name: "math_reasoning"
    description: "Mathematical reasoning with OpenR1-Math-220k dataset"
    dataset_path: "open-r1/OpenR1-Math-220k"  # HuggingFace dataset
    inherit_from: "basic_cot"
    adapter_config:
      r: 16
      lora_alpha: 32
      lora_dropout: 0.1
      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 1  # Large dataset, fewer epochs
      per_device_batch_size: 4
      gradient_accumulation_steps: 4
      learning_rate: 3e-4
      warmup_steps: 100
      max_length: 2048
      bf16: true
      max_grad_norm: 1.0
      weight_decay: 0.001
      save_steps: 1000
      logging_steps: 100
    dataset_config:
      # OpenR1-Math-220k specific settings
      streaming: true  # Use streaming for large dataset
      max_samples: 200000  # Limit samples for faster training
      split: "train"

  - name: "complex_reasoning"
    description: "Complex multi-step reasoning with Mixture-of-Thoughts"
    dataset_path: "open-r1/Mixture-of-Thoughts"  # HuggingFace dataset
    inherit_from: "math_reasoning"
    adapter_config:
      r: 32
      lora_alpha: 64
      lora_dropout: 0.1
      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 1  # Large dataset, fewer epochs
      per_device_batch_size: 2
      gradient_accumulation_steps: 8
      learning_rate: 2e-4
      warmup_steps: 200
      max_length: 4096
      bf16: true
      max_grad_norm: 1.0
      weight_decay: 0.001
      save_steps: 500
      logging_steps: 50
    dataset_config:
      # Mixture-of-Thoughts specific settings
      streaming: true  # Use streaming for large dataset
      max_samples: 30000  # Limit samples for faster training
      split: "train"

evaluation:
  benchmarks:
    - "HLE"
    - "Do-Not-Answer"
  save_results: true
  results_dir: "./outputs/evaluation_results"

133  config/training_config_gemma3_1b_cpu_offload.yaml  Normal file
@@ -0,0 +1,133 @@
experiment:
  name: "progressive_reasoning_gemma3_1b_cpu_offload"
  base_model: "google/gemma-3-1b-pt"  # Using Gemma 3 1B
  output_dir: "./outputs"
  use_wandb: true
  wandb_project: "matsuo-llm-comp-2025"

model:
  load_in_4bit: true  # Enable 4-bit quantization for QLoRA
  bnb_4bit_compute_dtype: "bfloat16"
  bnb_4bit_use_double_quant: true
  bnb_4bit_quant_type: "nf4"
  device_map: "auto"  # Let accelerate handle device placement
  max_memory:
    0: "5GB"  # Limit GPU memory to 3GB (leave room for CUDA kernels)
    "cpu": "32GB"  # Allow up to 32GB CPU RAM
  offload_folder: "./offload"  # Directory for disk offloading if needed
  gradient_checkpointing: true  # Trade compute for memory
  use_flash_attention_2: false
  use_eager_attention: true

progressive_stages:
  - name: "basic_cot"
    description: "Basic Chain-of-Thought reasoning"
    dataset_path: "./data/basic_cot/"
    adapter_config:
      r: 8  # Lower rank for memory efficiency
      lora_alpha: 16
      lora_dropout: 0.1
      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
      init_lora_weights: true
    training:
      num_epochs: 2
      per_device_batch_size: 2  # Smaller batch size
      gradient_accumulation_steps: 8  # Compensate with gradient accumulation
      learning_rate: 5e-4
      warmup_steps: 50
      max_length: 512  # Shorter sequences for memory
      bf16: true
      max_grad_norm: 1.0
      weight_decay: 0.001
      save_steps: 100
      logging_steps: 10
      dataloader_num_workers: 0  # Disable multiprocessing to save memory
      optim: "paged_adamw_8bit"  # Use 8-bit optimizer

  - name: "math_reasoning"
    description: "Mathematical reasoning with OpenR1-Math-220k dataset"
    dataset_path: "open-r1/OpenR1-Math-220k"
    inherit_from: "basic_cot"
    adapter_config:
      r: 16
      lora_alpha: 32
      lora_dropout: 0.1
      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 1
      per_device_batch_size: 1  # Minimal batch size
      gradient_accumulation_steps: 16
      learning_rate: 3e-4
      warmup_steps: 100
      max_length: 1024
      bf16: true
      max_grad_norm: 1.0
      weight_decay: 0.001
      save_steps: 1000
      logging_steps: 100
      optim: "paged_adamw_8bit"
    dataset_config:
      streaming: true
      max_samples: 200000  # Reduced for testing
      split: "train"

  - name: "complex_reasoning"
    description: "Complex multi-step reasoning with Mixture-of-Thoughts"
    dataset_path: "open-r1/Mixture-of-Thoughts"  # HuggingFace dataset
    inherit_from: "math_reasoning"
    adapter_config:
      r: 32
      lora_alpha: 64
      lora_dropout: 0.1
      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 1
      per_device_batch_size: 1
      gradient_accumulation_steps: 32
      learning_rate: 2e-4
      warmup_steps: 200
      max_length: 2048
      bf16: true
      max_grad_norm: 1.0
      weight_decay: 0.001
      optim: "paged_adamw_8bit"
      save_steps: 500
      logging_steps: 50
    dataset_config:
      streaming: true
      max_samples: 300000  # Limited for CPU offload config
      split: "train"

evaluation:
  benchmarks:
    - "HLE"
    - "Do-Not-Answer"
  save_results: true
  results_dir: "./outputs/evaluation_results"

# DeepSpeed configuration for advanced CPU offloading (optional)
# Uncomment to use DeepSpeed ZeRO-2 with CPU offload
# deepspeed:
#   zero_optimization:
#     stage: 2
#     offload_optimizer:
#       device: "cpu"
#       pin_memory: true
#     offload_param:
#       device: "cpu"
#       pin_memory: true
#     overlap_comm: true
#     contiguous_gradients: true
#     sub_group_size: 1e9
#     reduce_bucket_size: 1e6

# FSDP configuration for distributed training (optional)
# Uncomment to use FSDP with CPU offload
# fsdp:
#   sharding_strategy: "FULL_SHARD"
#   cpu_offload: true
#   auto_wrap_policy: "TRANSFORMER_BASED_WRAP"
#   transformer_layer_cls_to_wrap: "GemmaDecoderLayer"
#   min_num_params: 1e6

98  config/training_config_large.yaml  Normal file
@@ -0,0 +1,98 @@
experiment:
  name: "progressive_reasoning_large_model"
  base_model: "meta-llama/Llama-3.2-8B"  # Or other whitelisted models
  output_dir: "./outputs"
  use_wandb: true
  wandb_project: "matsuo-llm-comp-2025"

model:
  load_in_4bit: true  # Enable 4-bit quantization for memory efficiency
  bnb_4bit_compute_dtype: "bfloat16"
  bnb_4bit_use_double_quant: true
  bnb_4bit_quant_type: "nf4"
  device_map: "auto"
  # Additional memory optimizations
  gradient_checkpointing: true
  use_flash_attention_2: true  # If available

progressive_stages:
  - name: "basic_cot"
    description: "Basic Chain-of-Thought reasoning"
    dataset_path: "./data/basic_cot/"
    adapter_config:
      r: 16  # Larger rank for bigger models
      lora_alpha: 32
      lora_dropout: 0.05
      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"]
      init_lora_weights: true  # Identity initialization
    training:
      num_epochs: 1
      per_device_batch_size: 1  # Small batch size for large models
      gradient_accumulation_steps: 16  # Effective batch size = 16
      learning_rate: 2e-4
      warmup_steps: 100
      max_length: 2048
      fp16: false
      bf16: true
      max_grad_norm: 0.3
      weight_decay: 0.001
      save_steps: 50
      logging_steps: 10

  - name: "math_reasoning"
    description: "Mathematical reasoning with think tags"
    dataset_path: "./data/math_reasoning/"
    inherit_from: "basic_cot"
    adapter_config:
      r: 32  # Increase rank for more complex tasks
      lora_alpha: 64
      lora_dropout: 0.05
      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 2
      per_device_batch_size: 1
      gradient_accumulation_steps: 32  # Effective batch size = 32
      learning_rate: 1e-4
      warmup_steps: 200
      max_length: 4096
      bf16: true
      max_grad_norm: 0.3
      weight_decay: 0.001

  - name: "complex_reasoning"
    description: "Complex multi-step reasoning"
    dataset_path: "./data/complex_reasoning/"
    inherit_from: "math_reasoning"
    adapter_config:
      r: 64  # Maximum rank for most complex tasks
      lora_alpha: 128
      lora_dropout: 0.05
      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 2
      per_device_batch_size: 1
      gradient_accumulation_steps: 64  # Effective batch size = 64
      learning_rate: 5e-5
      warmup_steps: 300
      max_length: 8192
      bf16: true
      max_grad_norm: 0.3
      weight_decay: 0.001

evaluation:
  benchmarks:
    - "HLE"
    - "Do-Not-Answer"
  save_results: true
  results_dir: "./outputs/evaluation_results"

# Memory optimization settings
optimization:
  gradient_checkpointing: true
  gradient_checkpointing_kwargs:
    use_reentrant: false
  ddp_find_unused_parameters: false
  fsdp: "full_shard auto_wrap"  # For multi-GPU setups
  fsdp_transformer_layer_cls_to_wrap: "LlamaDecoderLayer"

85  config/training_config_llama_auth.yaml  Normal file
@@ -0,0 +1,85 @@
experiment:
  name: "progressive_reasoning_llama_auth"
  base_model: "meta-llama/Llama-3.2-8B"
  output_dir: "./outputs"
  use_wandb: true
  wandb_project: "matsuo-llm-comp-2025"

model:
  load_in_4bit: true
  bnb_4bit_compute_dtype: "bfloat16"
  bnb_4bit_use_double_quant: true
  bnb_4bit_quant_type: "nf4"
  device_map: "auto"
  gradient_checkpointing: true
  use_flash_attention_2: true
  # Add your HuggingFace token here, or set HF_TOKEN environment variable
  # hf_token: "your_token_here"

progressive_stages:
  - name: "basic_cot"
    description: "Basic Chain-of-Thought reasoning"
    dataset_path: "./data/basic_cot/"
    adapter_config:
      r: 16
      lora_alpha: 32
      lora_dropout: 0.05
      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"]
      init_lora_weights: true
    training:
      num_epochs: 1
      per_device_batch_size: 1
      gradient_accumulation_steps: 16
      learning_rate: 2e-4
      warmup_steps: 100
      max_length: 2048
      bf16: true
      max_grad_norm: 0.3
      weight_decay: 0.001

  - name: "math_reasoning"
    description: "Mathematical reasoning with think tags"
    dataset_path: "./data/math_reasoning/"
    inherit_from: "basic_cot"
    adapter_config:
      r: 32
      lora_alpha: 64
      lora_dropout: 0.05
      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 2
      per_device_batch_size: 1
      gradient_accumulation_steps: 32
      learning_rate: 1e-4
      warmup_steps: 200
      max_length: 4096
      bf16: true
      max_grad_norm: 0.3

  - name: "complex_reasoning"
    description: "Complex multi-step reasoning"
    dataset_path: "./data/complex_reasoning/"
    inherit_from: "math_reasoning"
    adapter_config:
      r: 64
      lora_alpha: 128
      lora_dropout: 0.05
      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
      init_lora_weights: true
    training:
      num_epochs: 2
      per_device_batch_size: 1
      gradient_accumulation_steps: 64
      learning_rate: 5e-5
      warmup_steps: 300
      max_length: 8192
      bf16: true
      max_grad_norm: 0.3

evaluation:
  benchmarks:
    - "HLE"
    - "Do-Not-Answer"
  save_results: true
  results_dir: "./outputs/evaluation_results"

82  config/training_config_public.yaml  Normal file
@@ -0,0 +1,82 @@
experiment:
  name: "progressive_reasoning_public_model"
  base_model: "microsoft/DialoGPT-medium"  # Public model, no authentication needed
  output_dir: "./outputs"
  use_wandb: false
  wandb_project: "matsuo-llm-comp-2025"

model:
  load_in_4bit: false  # DialoGPT is smaller, quantization not needed
  bnb_4bit_compute_dtype: "bfloat16"
  bnb_4bit_use_double_quant: true
  device_map: "auto"
  gradient_checkpointing: false

progressive_stages:
  - name: "basic_cot"
    description: "Basic Chain-of-Thought reasoning"
    dataset_path: "./data/basic_cot/"
    adapter_config:
      r: 16
      lora_alpha: 32
      lora_dropout: 0.1
      target_modules: ["c_attn", "c_proj"]  # GPT-2 style attention modules
      init_lora_weights: true
    training:
      num_epochs: 2
      per_device_batch_size: 4
      gradient_accumulation_steps: 4
      learning_rate: 2e-4
      warmup_steps: 100
      max_length: 1024
      fp16: false
      bf16: false  # Use fp32 for smaller models
      max_grad_norm: 1.0
      weight_decay: 0.001

  - name: "math_reasoning"
    description: "Mathematical reasoning with think tags"
    dataset_path: "./data/math_reasoning/"
    inherit_from: "basic_cot"
    adapter_config:
      r: 32
      lora_alpha: 64
      lora_dropout: 0.1
      target_modules: ["c_attn", "c_proj"]
      init_lora_weights: true
    training:
      num_epochs: 3
      per_device_batch_size: 2
      gradient_accumulation_steps: 8
      learning_rate: 1e-4
      warmup_steps: 200
      max_length: 2048
      bf16: false
      max_grad_norm: 1.0

  - name: "complex_reasoning"
    description: "Complex multi-step reasoning"
    dataset_path: "./data/complex_reasoning/"
    inherit_from: "math_reasoning"
    adapter_config:
      r: 64
      lora_alpha: 128
      lora_dropout: 0.1
      target_modules: ["c_attn", "c_proj"]
      init_lora_weights: true
    training:
      num_epochs: 2
      per_device_batch_size: 1
      gradient_accumulation_steps: 16
      learning_rate: 5e-5
      warmup_steps: 300
      max_length: 4096
      bf16: false
      max_grad_norm: 1.0

evaluation:
  benchmarks:
    - "HLE"
    - "Do-Not-Answer"
  save_results: true
  results_dir: "./outputs/evaluation_results"

139  devenv.lock  Normal file
@@ -0,0 +1,139 @@
{
  "nodes": {
    "devenv": {
      "locked": {
        "dir": "src/modules",
        "lastModified": 1751909516,
        "owner": "cachix",
        "repo": "devenv",
        "rev": "36e4cf7d6cb89862e69efce4e5c147ac2e4d38f9",
        "type": "github"
      },
      "original": {
        "dir": "src/modules",
        "owner": "cachix",
        "repo": "devenv",
        "type": "github"
      }
    },
    "flake-compat": {
      "flake": false,
      "locked": {
        "lastModified": 1747046372,
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-compat_2": {
      "flake": false,
      "locked": {
        "lastModified": 1747046372,
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "git-hooks": {
      "inputs": {
        "flake-compat": "flake-compat",
        "gitignore": "gitignore",
        "nixpkgs": [
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1750779888,
        "owner": "cachix",
        "repo": "git-hooks.nix",
        "rev": "16ec914f6fb6f599ce988427d9d94efddf25fe6d",
        "type": "github"
      },
      "original": {
        "owner": "cachix",
        "repo": "git-hooks.nix",
        "type": "github"
      }
    },
    "gitignore": {
      "inputs": {
        "nixpkgs": [
          "git-hooks",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1709087332,
        "owner": "hercules-ci",
        "repo": "gitignore.nix",
        "rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
        "type": "github"
      },
      "original": {
        "owner": "hercules-ci",
        "repo": "gitignore.nix",
        "type": "github"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1751792365,
        "owner": "NixOS",
        "repo": "nixpkgs",
        "rev": "1fd8bada0b6117e6c7eb54aad5813023eed37ccb",
        "type": "github"
      },
      "original": {
        "owner": "NixOS",
        "ref": "nixos-unstable",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "nixpkgs-python": {
      "inputs": {
        "flake-compat": "flake-compat_2",
        "nixpkgs": [
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1749760516,
        "owner": "cachix",
        "repo": "nixpkgs-python",
        "rev": "908dbb466af5955ea479ac95953333fd64387216",
        "type": "github"
      },
      "original": {
        "owner": "cachix",
        "repo": "nixpkgs-python",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "devenv": "devenv",
        "git-hooks": "git-hooks",
        "nixpkgs": "nixpkgs",
        "nixpkgs-python": "nixpkgs-python",
        "pre-commit-hooks": [
          "git-hooks"
        ]
      }
    }
  },
  "root": "root",
  "version": 7
}

95
flake-minimal.nix
Normal file
95
flake-minimal.nix
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
{
  description = "Progressive LLM Training for 松尾研LLMコンペ2025 (Minimal)";

  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
    flake-utils.url = "github:numtide/flake-utils";
  };

  outputs = { self, nixpkgs, flake-utils }:
    flake-utils.lib.eachDefaultSystem (system:
      let
        pkgs = import nixpkgs {
          inherit system;
          config = {
            allowUnfree = true;
            cudaSupport = true;
          };
        };

        # Python 3.11 for better compatibility
        python = pkgs.python311;

        # Minimal Python packages
        pythonWithPackages = python.withPackages (ps: with ps; [
          # Core essentials only
          torch
          transformers
          numpy

          # Essential dependencies
          pyyaml

          # Build tools
          pip
          setuptools
          wheel
        ]);

      in
      {
        devShells.default = pkgs.mkShell {
          buildInputs = with pkgs; [
            # Python with packages
            pythonWithPackages

            # Build tools
            gcc
            cmake
            ninja
            pkg-config

            # Git
            git
            git-lfs

            # Libraries needed for Python packages
            openssl
            zlib
            glib
            stdenv.cc.cc.lib

            # CUDA support
            cudaPackages.cudatoolkit
            cudaPackages.cudnn
          ];

          shellHook = ''
            echo "🚀 Progressive LLM Training Environment (Minimal)"
            echo "Python version: $(python --version)"
            echo "PyTorch version: $(python -c 'import torch; print(torch.__version__)')"
            echo "CUDA available: $(python -c 'import torch; print(torch.cuda.is_available())')"

            # Set up CUDA environment
            export CUDA_HOME=${pkgs.cudaPackages.cudatoolkit}
            export CUDA_PATH=${pkgs.cudaPackages.cudatoolkit}
            export LD_LIBRARY_PATH=${pkgs.cudaPackages.cudatoolkit}/lib:${pkgs.cudaPackages.cudnn}/lib:${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH

            # Set Python path
            export PYTHONPATH=$PWD/src:$PYTHONPATH

            echo ""
            echo "Note: This is a minimal configuration. Install additional packages with pip as needed:"
            echo " pip install accelerate peft trl datasets bitsandbytes wandb jsonlines scikit-learn sentencepiece protobuf"
            echo " pip install flash-attn --no-build-isolation"
          '';

          # Environment variables
          CUDA_HOME = "${pkgs.cudaPackages.cudatoolkit}";
          CUDA_PATH = "${pkgs.cudaPackages.cudatoolkit}";
          NIX_SHELL_PRESERVE_PROMPT = 1;
          LOCALE_ARCHIVE = "${pkgs.glibcLocales}/lib/locale/locale-archive";
          LC_ALL = "en_US.UTF-8";
        };
      });
}
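A usage sketch (the nix develop command is standard; the assumption is that this file is copied or symlinked to flake.nix at the repository root, since flakes are only discovered under that name):

    nix develop

Inside the shell, the extra pip installs echoed by the shellHook can then be run as needed.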
61
flake.lock
generated
Normal file
@ -0,0 +1,61 @@
{
  "nodes": {
    "flake-utils": {
      "inputs": {
        "systems": "systems"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1751792365,
        "narHash": "sha256-J1kI6oAj25IG4EdVlg2hQz8NZTBNYvIS0l4wpr9KcUo=",
        "owner": "NixOS",
        "repo": "nixpkgs",
        "rev": "1fd8bada0b6117e6c7eb54aad5813023eed37ccb",
        "type": "github"
      },
      "original": {
        "owner": "NixOS",
        "ref": "nixos-unstable",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "flake-utils": "flake-utils",
        "nixpkgs": "nixpkgs"
      }
    },
    "systems": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    }
  },
  "root": "root",
  "version": 7
}
195
flake.nix
Normal file
@ -0,0 +1,195 @@
{
|
||||
description = "Progressive LLM Training for 松尾研LLMコンペ2025";
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
};
|
||||
|
||||
outputs = { self, nixpkgs, flake-utils }:
|
||||
flake-utils.lib.eachDefaultSystem (system:
|
||||
let
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
config = {
|
||||
allowUnfree = true;
|
||||
cudaSupport = true;
|
||||
};
|
||||
overlays = [
|
||||
(final: prev: {
|
||||
python311 = prev.python311.override {
|
||||
packageOverrides = python-self: python-super: {
|
||||
# Disable tests for problematic packages
|
||||
pytest-doctestplus = python-super.pytest-doctestplus.overrideAttrs (oldAttrs: {
|
||||
doCheck = false;
|
||||
doInstallCheck = false;
|
||||
pytestCheckPhase = "echo 'Skipping tests'";
|
||||
});
|
||||
# Also disable tests for jupyter-related packages if they cause issues
|
||||
jupyter = python-super.jupyter.overrideAttrs (oldAttrs: {
|
||||
doCheck = false;
|
||||
doInstallCheck = false;
|
||||
});
|
||||
notebook = python-super.notebook.overrideAttrs (oldAttrs: {
|
||||
doCheck = false;
|
||||
doInstallCheck = false;
|
||||
});
|
||||
# Disable tests for psycopg and psycopg2
|
||||
psycopg = python-super.psycopg.overrideAttrs (oldAttrs: {
|
||||
doCheck = false;
|
||||
doInstallCheck = false;
|
||||
pytestCheckPhase = "echo 'Skipping tests'";
|
||||
pythonImportsCheck = []; # Disable import checks
|
||||
});
|
||||
psycopg2 = python-super.psycopg2.overrideAttrs (oldAttrs: {
|
||||
doCheck = false;
|
||||
doInstallCheck = false;
|
||||
pytestCheckPhase = "echo 'Skipping tests'";
|
||||
pythonImportsCheck = []; # Disable import checks
|
||||
});
|
||||
# Disable tests for sqlframe
|
||||
sqlframe = python-super.sqlframe.overrideAttrs (oldAttrs: {
|
||||
doCheck = false;
|
||||
doInstallCheck = false;
|
||||
pytestCheckPhase = "echo 'Skipping tests'";
|
||||
pythonImportsCheck = []; # Disable import checks
|
||||
});
|
||||
# Disable tests for accelerate
|
||||
accelerate = python-super.accelerate.overrideAttrs (oldAttrs: {
|
||||
doCheck = false;
|
||||
doInstallCheck = false;
|
||||
pytestCheckPhase = "echo 'Skipping tests'";
|
||||
pythonImportsCheck = []; # Disable import checks
|
||||
});
|
||||
};
|
||||
};
|
||||
})
|
||||
];
|
||||
};
|
||||
|
||||
# Python 3.11 for better compatibility
|
||||
python = pkgs.python311;
|
||||
|
||||
# Python packages
|
||||
pythonWithPackages = python.withPackages (ps: with ps; [
|
||||
# Core ML packages
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
transformers
|
||||
accelerate
|
||||
datasets
|
||||
tokenizers
|
||||
scikit-learn
|
||||
|
||||
# Required dependencies from requirements.txt
|
||||
pyyaml
|
||||
jsonlines
|
||||
sentencepiece
|
||||
protobuf
|
||||
|
||||
# Additional useful packages
|
||||
numpy
|
||||
scipy
|
||||
matplotlib
|
||||
jupyter
|
||||
notebook
|
||||
ipython
|
||||
pandas
|
||||
rich # For TUI
|
||||
|
||||
# Development tools
|
||||
black
|
||||
flake8
|
||||
pytest
|
||||
mypy
|
||||
|
||||
# Build tools
|
||||
pip
|
||||
setuptools
|
||||
wheel
|
||||
|
||||
# LLM specific packages
|
||||
peft
|
||||
trl
|
||||
bitsandbytes
|
||||
wandb
|
||||
]);
|
||||
|
||||
in
|
||||
{
|
||||
devShells.default = pkgs.mkShell {
|
||||
buildInputs = with pkgs; [
|
||||
# Python with packages
|
||||
pythonWithPackages
|
||||
|
||||
# Build tools
|
||||
gcc
|
||||
cmake
|
||||
ninja
|
||||
pkg-config
|
||||
|
||||
# Git
|
||||
git
|
||||
git-lfs
|
||||
|
||||
# Development tools
|
||||
htop
|
||||
tmux
|
||||
vim
|
||||
|
||||
# Libraries needed for Python packages
|
||||
openssl
|
||||
zlib
|
||||
glib
|
||||
stdenv.cc.cc.lib
|
||||
|
||||
# CUDA support
|
||||
cudaPackages.cudatoolkit
|
||||
cudaPackages.cudnn
|
||||
];
|
||||
|
||||
shellHook = ''
|
||||
echo "🚀 Progressive LLM Training Environment"
|
||||
echo "Python version: $(python --version)"
|
||||
echo "PyTorch version: $(python -c 'import torch; print(torch.__version__)')"
|
||||
echo "CUDA available: $(python -c 'import torch; print(torch.cuda.is_available())')"
|
||||
|
||||
# Set up CUDA environment
|
||||
export CUDA_HOME=${pkgs.cudaPackages.cudatoolkit}
|
||||
export CUDA_PATH=${pkgs.cudaPackages.cudatoolkit}
|
||||
export LD_LIBRARY_PATH=${pkgs.cudaPackages.cudatoolkit}/lib:${pkgs.cudaPackages.cudnn}/lib:${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH
|
||||
|
||||
# Set Python path
|
||||
export PYTHONPATH=$PWD/src:$PYTHONPATH
|
||||
|
||||
echo ""
|
||||
echo "Available commands:"
|
||||
echo " python scripts/train_progressive.py # Start training"
|
||||
echo " python scripts/evaluate.py # Evaluate model"
|
||||
echo " jupyter notebook # Start Jupyter"
|
||||
echo ""
|
||||
|
||||
# Create data directory if not exists
|
||||
mkdir -p data
|
||||
|
||||
# Prepare sample data if not exists
|
||||
if [ ! -f "data/basic_cot/train.jsonl" ]; then
|
||||
echo "Preparing sample datasets..."
|
||||
python -c "from src.data_utils import prepare_sample_datasets; prepare_sample_datasets()" || echo "Sample data preparation skipped"
|
||||
fi
|
||||
|
||||
# Note about flash-attn
|
||||
echo "Note: flash-attn is not included in nixpkgs. If needed, install manually with:"
|
||||
echo " pip install flash-attn --no-build-isolation"
|
||||
'';
|
||||
|
||||
# Environment variables
|
||||
CUDA_HOME = "${pkgs.cudaPackages.cudatoolkit}";
|
||||
CUDA_PATH = "${pkgs.cudaPackages.cudatoolkit}";
|
||||
NIX_SHELL_PRESERVE_PROMPT = 1;
|
||||
LOCALE_ARCHIVE = "${pkgs.glibcLocales}/lib/locale/locale-archive";
|
||||
LC_ALL = "en_US.UTF-8";
|
||||
};
|
||||
});
|
||||
}
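Since this file is the repository's flake.nix, the full development shell is entered with:

    nix develop

The shellHook then prints the Python/PyTorch/CUDA versions, prepares the data directory, and reminds you that flash-attn still has to be installed manually with pip, as noted above.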
15
requirements-cpu.txt
Normal file
@ -0,0 +1,15 @@
# CPU version of PyTorch
torch>=2.0.0 --index-url https://download.pytorch.org/whl/cpu
transformers>=4.40.0
accelerate>=0.27.0
peft>=0.11.0
trl>=0.9.0
datasets>=2.18.0
bitsandbytes>=0.43.0
wandb>=0.16.0
pyyaml>=6.0
jsonlines>=4.0.0
scikit-learn>=1.3.0
# flash-attn is not needed for CPU version
sentencepiece>=0.2.0
protobuf>=4.25.0
3
requirements-torch.txt
Normal file
@ -0,0 +1,3 @@
--index-url https://download.pytorch.org/whl/cu128
torch>=2.0.0
torchaudio>=2.0.0
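This file only pins the CUDA 12.8 PyTorch wheel index; presumably it is installed before the main requirements, along the lines of:

    pip install -r requirements-torch.txt
    pip install -r requirements.txt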
13
requirements.txt
Normal file
@ -0,0 +1,13 @@
transformers>=4.40.0
accelerate>=0.27.0
peft>=0.11.0
trl>=0.9.0
datasets>=2.18.0
bitsandbytes>=0.43.0
wandb>=0.16.0
pyyaml>=6.0
jsonlines>=4.0.0
scikit-learn>=1.3.0
# flash-attn>=2.5.0 # Install separately with --no-build-isolation
sentencepiece>=0.2.0
protobuf>=4.25.0
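As the commented-out line above notes, flash-attn is excluded from this file and installed separately:

    pip install flash-attn --no-build-isolation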
137
scripts/analyze_adapter_size.py
Executable file
@ -0,0 +1,137 @@
#!/usr/bin/env python3
|
||||
"""
|
||||
Analyze the size and structure of LoRA adapters
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import yaml
|
||||
from peft import PeftModel, LoraConfig
|
||||
|
||||
# Add src to path
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
from src.progressive_model import ProgressiveReasoningModel
|
||||
|
||||
|
||||
def analyze_adapter_sizes():
|
||||
# Load configuration
|
||||
with open("config/training_config.yaml") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
print("=" * 60)
|
||||
print("LoRA Adapter Size Analysis")
|
||||
print("=" * 60)
|
||||
|
||||
# Get adapter configuration from config
|
||||
basic_cot_config = config["progressive_stages"][0]
|
||||
adapter_config = basic_cot_config["adapter_config"]
|
||||
|
||||
print(f"\nConfiguration for 'basic_cot' adapter:")
|
||||
print(f" - r (rank): {adapter_config['r']}")
|
||||
print(f" - lora_alpha: {adapter_config['lora_alpha']}")
|
||||
print(f" - lora_dropout: {adapter_config['lora_dropout']}")
|
||||
print(f" - target_modules: {adapter_config['target_modules']}")
|
||||
|
||||
# Load the base model to get dimensions
|
||||
print("\nLoading base model to analyze dimensions...")
|
||||
model_wrapper = ProgressiveReasoningModel(config)
|
||||
model_wrapper.setup_base_model()
|
||||
|
||||
# Analyze model architecture
|
||||
print(f"\nBase model: {config['experiment']['base_model']}")
|
||||
|
||||
# Count parameters in base model
|
||||
total_params = sum(p.numel() for p in model_wrapper.model.parameters())
|
||||
print(f"Total base model parameters: {total_params:,}")
|
||||
|
||||
# Load saved adapter if it exists
|
||||
adapter_path = Path(config["experiment"]["output_dir"]) / "adapters" / "basic_cot"
|
||||
if adapter_path.exists():
|
||||
print(f"\nLoading saved adapter from: {adapter_path}")
|
||||
|
||||
# Load adapter state dict
|
||||
adapter_model_path = adapter_path / "adapter_model.safetensors"
|
||||
if not adapter_model_path.exists():
|
||||
adapter_model_path = adapter_path / "adapter_model.bin"
|
||||
|
||||
if adapter_model_path.exists():
|
||||
if adapter_model_path.suffix == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
adapter_weights = load_file(adapter_model_path)
|
||||
else:
|
||||
adapter_weights = torch.load(adapter_model_path, map_location="cpu")
|
||||
|
||||
print("\nLoRA Adapter Layer Details:")
|
||||
print("-" * 60)
|
||||
|
||||
total_lora_params = 0
|
||||
layer_info = {}
|
||||
|
||||
for name, tensor in adapter_weights.items():
|
||||
size = tensor.numel()
|
||||
total_lora_params += size
|
||||
|
||||
# Parse layer name
|
||||
parts = name.split('.')
|
||||
if 'lora_A' in name or 'lora_B' in name:
|
||||
# Extract module info
|
||||
module_name = '.'.join(parts[:-2])
|
||||
lora_type = parts[-2] # lora_A or lora_B
|
||||
|
||||
if module_name not in layer_info:
|
||||
layer_info[module_name] = {}
|
||||
|
||||
layer_info[module_name][lora_type] = {
|
||||
'shape': list(tensor.shape),
|
||||
'params': size
|
||||
}
|
||||
|
||||
# Display layer information
|
||||
for module, info in sorted(layer_info.items()):
|
||||
print(f"\nModule: {module}")
|
||||
if 'lora_A' in info and 'lora_B' in info:
|
||||
shape_a = info['lora_A']['shape']
|
||||
shape_b = info['lora_B']['shape']
|
||||
params_a = info['lora_A']['params']
|
||||
params_b = info['lora_B']['params']
|
||||
|
||||
print(f" LoRA A: {shape_a} = {params_a:,} parameters")
|
||||
print(f" LoRA B: {shape_b} = {params_b:,} parameters")
|
||||
print(f" Total: {params_a + params_b:,} parameters")
|
||||
|
||||
# Calculate original layer size (approximation)
|
||||
original_size = shape_a[1] * shape_b[0]
|
||||
compression_ratio = original_size / (params_a + params_b)
|
||||
print(f" Original layer size (approx): {original_size:,} parameters")
|
||||
print(f" Compression ratio: {compression_ratio:.1f}x")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"Total LoRA parameters: {total_lora_params:,}")
|
||||
print(f"Percentage of base model: {(total_lora_params / total_params) * 100:.2f}%")
|
||||
|
||||
# Calculate theoretical size
|
||||
r = adapter_config['r']
|
||||
num_modules = len(adapter_config['target_modules'])
|
||||
|
||||
# For GPT models, typical dimensions
|
||||
if "DialoGPT" in config['experiment']['base_model']:
|
||||
hidden_size = 768 # DialoGPT-small uses 768
|
||||
print(f"\nTheoretical calculation (hidden_size={hidden_size}, r={r}):")
|
||||
print(f" Per module: 2 * {hidden_size} * {r} = {2 * hidden_size * r:,} parameters")
|
||||
print(f" Total ({num_modules} modules): {2 * hidden_size * r * num_modules:,} parameters")
|
||||
else:
|
||||
print(f"\nNo saved adapter found at: {adapter_path}")
|
||||
print("Run training first to generate the adapter.")
|
||||
|
||||
# Show theoretical sizes based on config
|
||||
r = adapter_config['r']
|
||||
print(f"\nTheoretical LoRA sizes with r={r}:")
|
||||
print(f" For hidden_size=768 (DialoGPT-small): {2 * 768 * r:,} params per module")
|
||||
print(f" For hidden_size=1024 (medium models): {2 * 1024 * r:,} params per module")
|
||||
print(f" For hidden_size=1280 (GPT-2 large): {2 * 1280 * r:,} params per module")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
analyze_adapter_sizes()
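To make the theoretical sizes printed above concrete, here is the same arithmetic as a tiny standalone sketch (r=16 is an assumed rank for illustration, not a value read from the config):

    # Each LoRA-adapted module adds A (r x hidden) and B (hidden x r) matrices.
    hidden_size = 768  # DialoGPT-small, as used in the script above
    r = 16             # assumed rank, for illustration only
    print(2 * hidden_size * r)  # 24576 parameters per adapted module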
199
scripts/check_vram.py
Normal file
@ -0,0 +1,199 @@
#!/usr/bin/env python3
|
||||
"""
|
||||
Check VRAM usage and model memory requirements
|
||||
"""
|
||||
|
||||
import torch
|
||||
import psutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
|
||||
# Add src to path
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||
|
||||
|
||||
def get_memory_info():
|
||||
"""Get current memory usage"""
|
||||
if torch.cuda.is_available():
|
||||
print("=== CUDA Information ===")
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
||||
print(f"CUDA device count: {torch.cuda.device_count()}")
|
||||
|
||||
# Get VRAM info
|
||||
vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
||||
vram_reserved = torch.cuda.memory_reserved(0) / 1024**3
|
||||
vram_allocated = torch.cuda.memory_allocated(0) / 1024**3
|
||||
vram_free = vram_total - vram_allocated
|
||||
|
||||
print(f"\n=== VRAM Usage ===")
|
||||
print(f"Total VRAM: {vram_total:.2f} GB")
|
||||
print(f"Allocated VRAM: {vram_allocated:.2f} GB")
|
||||
print(f"Reserved VRAM: {vram_reserved:.2f} GB")
|
||||
print(f"Free VRAM: {vram_free:.2f} GB")
|
||||
else:
|
||||
print("CUDA not available!")
|
||||
|
||||
# Get system RAM info
|
||||
ram = psutil.virtual_memory()
|
||||
print(f"\n=== System RAM ===")
|
||||
print(f"Total RAM: {ram.total / 1024**3:.2f} GB")
|
||||
print(f"Available RAM: {ram.available / 1024**3:.2f} GB")
|
||||
print(f"Used RAM: {ram.used / 1024**3:.2f} GB ({ram.percent}%)")
|
||||
|
||||
|
||||
def estimate_model_size(model_name: str, quantization: str = None):
|
||||
"""Estimate model memory requirements"""
|
||||
print(f"\n=== Model Memory Estimation ===")
|
||||
print(f"Model: {model_name}")
|
||||
|
||||
# Common model sizes (in billions of parameters)
|
||||
model_sizes = {
|
||||
"gemma-2-2b": 2.5,
|
||||
"gemma-3-1b": 1.2,
|
||||
"llama-3.2-8b": 8,
|
||||
"llama-3.2-13b": 13,
|
||||
"llama-3.2-70b": 70,
|
||||
}
|
||||
|
||||
# Find model size
|
||||
model_key = None
|
||||
for key in model_sizes:
|
||||
if key in model_name.lower():
|
||||
model_key = key
|
||||
break
|
||||
|
||||
if model_key:
|
||||
params_billions = model_sizes[model_key]
|
||||
|
||||
# Memory estimates (rough)
|
||||
fp32_gb = params_billions * 4 # 4 bytes per parameter
|
||||
fp16_gb = params_billions * 2 # 2 bytes per parameter
|
||||
int8_gb = params_billions * 1 # 1 byte per parameter
|
||||
int4_gb = params_billions * 0.5 # 0.5 bytes per parameter
|
||||
|
||||
print(f"Estimated parameters: {params_billions}B")
|
||||
print(f"Memory requirements:")
|
||||
print(f" FP32: ~{fp32_gb:.1f} GB")
|
||||
print(f" FP16/BF16: ~{fp16_gb:.1f} GB")
|
||||
print(f" INT8: ~{int8_gb:.1f} GB")
|
||||
print(f" INT4 (QLoRA): ~{int4_gb:.1f} GB")
|
||||
|
||||
# Add overhead for activations and gradients
|
||||
print(f"\nWith training overhead:")
|
||||
print(f" FP16 + LoRA: ~{fp16_gb * 1.5:.1f} GB")
|
||||
print(f" INT4 + QLoRA: ~{int4_gb * 1.5:.1f} GB")
|
||||
else:
|
||||
print("Model size not recognized, unable to estimate memory requirements")
|
||||
|
||||
|
||||
def suggest_offloading_strategies():
|
||||
"""Suggest CPU offloading strategies"""
|
||||
print("\n=== CPU Offloading Strategies ===")
|
||||
print("\n1. **Device Map Auto with CPU Offload**")
|
||||
print(" ```python")
|
||||
print(" device_map = {")
|
||||
print(" 'model.embed_tokens': 'cpu',")
|
||||
print(" 'model.layers.0': 0, # GPU")
|
||||
print(" 'model.layers.1': 0, # GPU")
|
||||
print(" 'model.layers.2': 'cpu', # CPU")
|
||||
print(" # ... distribute layers between GPU and CPU")
|
||||
print(" }")
|
||||
print(" ```")
|
||||
|
||||
print("\n2. **Accelerate's CPU Offload**")
|
||||
print(" ```yaml")
|
||||
print(" model:")
|
||||
print(" device_map: 'auto'")
|
||||
print(" max_memory:")
|
||||
print(" 0: '4GB' # Limit GPU memory")
|
||||
print(" 'cpu': '20GB' # Allow CPU memory")
|
||||
print(" ```")
|
||||
|
||||
print("\n3. **DeepSpeed ZeRO-Offload**")
|
||||
print(" - ZeRO-2: Offload optimizer states to CPU")
|
||||
print(" - ZeRO-3: Offload optimizer states and parameters to CPU")
|
||||
print(" ```yaml")
|
||||
print(" deepspeed:")
|
||||
print(" zero_optimization:")
|
||||
print(" stage: 2")
|
||||
print(" offload_optimizer:")
|
||||
print(" device: 'cpu'")
|
||||
print(" ```")
|
||||
|
||||
print("\n4. **Gradient Checkpointing + CPU Offload**")
|
||||
print(" - Trade compute for memory")
|
||||
print(" - Combine with layer-wise CPU offloading")
|
||||
|
||||
print("\n5. **QLoRA with CPU Offload**")
|
||||
print(" - 4-bit quantization reduces base model size")
|
||||
print(" - Only LoRA parameters on GPU")
|
||||
print(" - Base model layers can be on CPU")
|
||||
|
||||
|
||||
def check_config_compatibility(config_path: str):
|
||||
"""Check if config is compatible with CPU offloading"""
|
||||
if Path(config_path).exists():
|
||||
with open(config_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
print(f"\n=== Config Analysis: {config_path} ===")
|
||||
model_config = config.get("model", {})
|
||||
|
||||
print(f"Current settings:")
|
||||
print(f" 4-bit quantization: {model_config.get('load_in_4bit', False)}")
|
||||
print(f" Gradient checkpointing: {model_config.get('gradient_checkpointing', False)}")
|
||||
print(f" Device map: {model_config.get('device_map', 'None')}")
|
||||
|
||||
if model_config.get('load_in_4bit', False):
|
||||
print("✓ Already using 4-bit quantization (good for memory)")
|
||||
else:
|
||||
print("✗ Consider enabling 4-bit quantization")
|
||||
|
||||
if not model_config.get('gradient_checkpointing', False):
|
||||
print("✗ Consider enabling gradient checkpointing")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
print("VRAM and Memory Analysis for Progressive LLM Training")
|
||||
print("=" * 60)
|
||||
|
||||
# Get memory info
|
||||
get_memory_info()
|
||||
|
||||
# Estimate model sizes
|
||||
models = [
|
||||
"google/gemma-2-2b-it",
|
||||
"google/gemma-3-1b-pt",
|
||||
"meta-llama/Llama-3.2-8B",
|
||||
]
|
||||
|
||||
for model in models:
|
||||
estimate_model_size(model)
|
||||
|
||||
# Suggest strategies
|
||||
suggest_offloading_strategies()
|
||||
|
||||
# Check configs
|
||||
configs = [
|
||||
"config/training_config_gemma3_1b.yaml",
|
||||
"config/training_config_gemma2_small.yaml",
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
check_config_compatibility(config)
|
||||
|
||||
print("\n=== Recommendations ===")
|
||||
print("1. Start with QLoRA (4-bit) if not already enabled")
|
||||
print("2. Use device_map with max_memory limits")
|
||||
print("3. Enable gradient checkpointing")
|
||||
print("4. Consider DeepSpeed for advanced offloading")
|
||||
print("5. Monitor actual usage during training")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
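The offloading recommendations printed by this script can be combined in a single from_pretrained call; a minimal sketch with an illustrative model name and memory limits:

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # 4-bit quantization plus a GPU memory cap with CPU overflow
    # (strategies 1, 2 and 5 suggested above)
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b-it",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        device_map="auto",
        max_memory={0: "4GiB", "cpu": "20GiB"},
    )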
183
scripts/compare_models_tui.py
Executable file
@ -0,0 +1,183 @@
#!/usr/bin/env python3
|
||||
"""
|
||||
TUI for comparing original and trained models
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
import torch
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.columns import Columns
|
||||
from rich.prompt import Prompt
|
||||
from rich.text import Text
|
||||
from rich.layout import Layout
|
||||
from rich.live import Live
|
||||
from rich.table import Table
|
||||
import time
|
||||
|
||||
# Add src to path
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
from src.progressive_model import ProgressiveReasoningModel
|
||||
|
||||
|
||||
class ModelCompareTUI:
|
||||
def __init__(self, config_path: str = "config/training_config.yaml"):
|
||||
self.console = Console()
|
||||
|
||||
# Load configuration
|
||||
with open(config_path) as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
|
||||
# Initialize models
|
||||
self.console.print("[yellow]Loading models...[/yellow]")
|
||||
|
||||
# Original model
|
||||
self.original_model = ProgressiveReasoningModel(self.config)
|
||||
self.original_model.setup_base_model()
|
||||
|
||||
# Trained model
|
||||
self.trained_model = ProgressiveReasoningModel(self.config)
|
||||
self.trained_model.setup_base_model()
|
||||
|
||||
# Load the trained adapter if it exists
|
||||
adapter_path = Path(self.config["experiment"]["output_dir"]) / "adapters" / "basic_cot"
|
||||
if adapter_path.exists():
|
||||
self.console.print(f"[green]Loading trained adapter from: {adapter_path}[/green]")
|
||||
self.trained_model.load_for_inference(["basic_cot"])
|
||||
else:
|
||||
self.console.print("[red]No trained adapter found. Please run training first.[/red]")
|
||||
self.console.print("[yellow]Both models will show original behavior.[/yellow]")
|
||||
|
||||
self.console.print("[green]Models loaded successfully![/green]\n")
|
||||
|
||||
def generate_response(self, model, prompt: str, with_think_tags: bool = True) -> str:
|
||||
"""Generate response from a model"""
|
||||
# For trained model, encourage think tags
|
||||
if with_think_tags and model == self.trained_model:
|
||||
formatted_prompt = f"{prompt}\n\nPlease think step by step."
|
||||
else:
|
||||
formatted_prompt = prompt
|
||||
|
||||
inputs = model.tokenizer(formatted_prompt, return_tensors="pt").to(model.model.device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.model.generate(
|
||||
**inputs,
|
||||
max_length=512,
|
||||
temperature=0.7,
|
||||
do_sample=True,
|
||||
top_p=0.95,
|
||||
pad_token_id=model.tokenizer.pad_token_id,
|
||||
eos_token_id=model.tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
response = model.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
|
||||
# Extract response after prompt
|
||||
response = response[len(formatted_prompt):].strip()
|
||||
|
||||
return response
|
||||
|
||||
def create_comparison_panel(self, prompt: str, original_response: str, trained_response: str) -> Panel:
|
||||
"""Create a panel showing the comparison"""
|
||||
# Create table
|
||||
table = Table(show_header=True, header_style="bold magenta", expand=True)
|
||||
table.add_column("Original Model", style="cyan", width=50)
|
||||
table.add_column("Trained Model (with CoT)", style="green", width=50)
|
||||
|
||||
table.add_row(original_response, trained_response)
|
||||
|
||||
return Panel(
|
||||
table,
|
||||
title=f"[bold yellow]Prompt: {prompt}[/bold yellow]",
|
||||
border_style="blue"
|
||||
)
|
||||
|
||||
def run_interactive_mode(self):
|
||||
"""Run interactive comparison mode"""
|
||||
self.console.print("\n[bold cyan]Model Comparison TUI[/bold cyan]")
|
||||
self.console.print("Compare responses from original and trained models\n")
|
||||
self.console.print("[dim]Type 'quit' or 'exit' to leave[/dim]\n")
|
||||
|
||||
while True:
|
||||
# Get user prompt
|
||||
prompt = Prompt.ask("\n[bold yellow]Enter your prompt[/bold yellow]")
|
||||
|
||||
if prompt.lower() in ['quit', 'exit']:
|
||||
self.console.print("\n[yellow]Goodbye![/yellow]")
|
||||
break
|
||||
|
||||
# Generate responses
|
||||
self.console.print("\n[dim]Generating responses...[/dim]")
|
||||
|
||||
start_time = time.time()
|
||||
original_response = self.generate_response(self.original_model, prompt, with_think_tags=False)
|
||||
original_time = time.time() - start_time
|
||||
|
||||
start_time = time.time()
|
||||
trained_response = self.generate_response(self.trained_model, prompt, with_think_tags=True)
|
||||
trained_time = time.time() - start_time
|
||||
|
||||
# Display comparison
|
||||
panel = self.create_comparison_panel(prompt, original_response, trained_response)
|
||||
self.console.print(panel)
|
||||
|
||||
# Show generation times
|
||||
self.console.print(f"\n[dim]Generation times - Original: {original_time:.2f}s, Trained: {trained_time:.2f}s[/dim]")
|
||||
|
||||
def run_benchmark_mode(self):
|
||||
"""Run benchmark with predefined prompts"""
|
||||
test_prompts = [
|
||||
"What is 156 + 389?",
|
||||
"If I have 23 apples and buy 17 more, how many do I have?",
|
||||
"A store has 145 items. If 38 are sold, how many remain?",
|
||||
"What is 45 * 12?",
|
||||
"Explain why 2 + 2 = 4",
|
||||
"If a train travels 80 km/h for 2.5 hours, how far does it go?",
|
||||
"What is the sum of all numbers from 1 to 10?",
|
||||
"How many minutes are in 3.5 hours?",
|
||||
]
|
||||
|
||||
self.console.print("\n[bold cyan]Running Benchmark Comparison[/bold cyan]\n")
|
||||
|
||||
for i, prompt in enumerate(test_prompts, 1):
|
||||
self.console.print(f"[bold]Test {i}/{len(test_prompts)}[/bold]")
|
||||
|
||||
# Generate responses
|
||||
original_response = self.generate_response(self.original_model, prompt, with_think_tags=False)
|
||||
trained_response = self.generate_response(self.trained_model, prompt, with_think_tags=True)
|
||||
|
||||
# Display comparison
|
||||
panel = self.create_comparison_panel(prompt, original_response, trained_response)
|
||||
self.console.print(panel)
|
||||
self.console.print("")
|
||||
|
||||
self.console.print("[green]Benchmark completed![/green]")
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Compare original and trained models")
|
||||
parser.add_argument("--mode", choices=["interactive", "benchmark"], default="interactive",
|
||||
help="Mode to run the comparison")
|
||||
parser.add_argument("--config", default="config/training_config.yaml",
|
||||
help="Path to configuration file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create TUI
|
||||
tui = ModelCompareTUI(args.config)
|
||||
|
||||
# Run in selected mode
|
||||
if args.mode == "interactive":
|
||||
tui.run_interactive_mode()
|
||||
else:
|
||||
tui.run_benchmark_mode()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
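Typical invocations, using the argparse options defined above:

    python scripts/compare_models_tui.py --mode interactive
    python scripts/compare_models_tui.py --mode benchmark --config config/training_config.yaml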
59
scripts/evaluate.py
Executable file
@ -0,0 +1,59 @@
#!/usr/bin/env python3
"""
Evaluation script for progressive model
"""

import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent))

from src.progressive_model import ProgressiveReasoningModel
import yaml


def evaluate_reasoning(model_wrapper, test_prompts):
    """Evaluate model on test prompts"""
    results = []

    for prompt in test_prompts:
        print(f"\nPrompt: {prompt}")
        response = model_wrapper.generate_with_reasoning(prompt)
        print(f"Response: {response}")
        results.append({
            "prompt": prompt,
            "response": response
        })

    return results


def main():
    # Load config
    with open("config/training_config.yaml") as f:
        config = yaml.safe_load(f)

    # Initialize model
    model_wrapper = ProgressiveReasoningModel(config)
    model_wrapper.setup_base_model()

    # Test different adapters
    test_prompts = [
        "What is 156 + 389?",
        "If a train travels 80 km/h for 2.5 hours, how far does it go?",
        "Explain why the sky is blue.",
    ]

    # Test each adapter
    for adapter_name in ["basic_cot", "math_reasoning", "complex_reasoning"]:
        if adapter_name in model_wrapper.adapters:
            print(f"\n{'='*50}")
            print(f"Testing adapter: {adapter_name}")
            print(f"{'='*50}")

            model_wrapper.load_for_inference([adapter_name])
            results = evaluate_reasoning(model_wrapper, test_prompts)


if __name__ == "__main__":
    main()
189
scripts/simple_compare.py
Executable file
@ -0,0 +1,189 @@
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple comparison script without rich TUI
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
# Add src to path
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
from src.progressive_model import ProgressiveReasoningModel
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Compare original and trained models")
|
||||
parser.add_argument(
|
||||
"--config", "-c",
|
||||
type=str,
|
||||
default="config/training_config_gemma2_small.yaml",
|
||||
help="Path to configuration file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter", "-a",
|
||||
type=str,
|
||||
default="basic_cot",
|
||||
help="Adapter name to load for comparison"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-length",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Maximum generation length"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_config(config_path):
|
||||
"""Load configuration from file"""
|
||||
config_path = Path(config_path)
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Configuration file not found: {config_path}")
|
||||
|
||||
with open(config_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
return config
|
||||
|
||||
|
||||
def generate_response(model, tokenizer, prompt, max_length=512):
|
||||
"""Generate response using the model"""
|
||||
# Format prompt for Gemma
|
||||
formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
|
||||
|
||||
# Tokenize
|
||||
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_length=len(inputs["input_ids"][0]) + max_length,
|
||||
temperature=0.7,
|
||||
do_sample=True,
|
||||
top_p=0.9,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
repetition_penalty=1.1,
|
||||
)
|
||||
|
||||
# Decode
|
||||
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
|
||||
# Extract only the model's response
|
||||
if "<start_of_turn>model" in response:
|
||||
response = response.split("<start_of_turn>model")[-1].strip()
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
try:
|
||||
config = load_config(args.config)
|
||||
except FileNotFoundError as e:
|
||||
print(f"Error: {e}")
|
||||
return
|
||||
|
||||
print(f"Progressive Model Comparison")
|
||||
print(f"Config: {args.config}")
|
||||
print(f"Base model: {config['experiment']['base_model']}")
|
||||
print(f"Adapter: {args.adapter}")
|
||||
print("="*60)
|
||||
|
||||
print("Loading models...")
|
||||
|
||||
# Original model (no adapter)
|
||||
print("Loading original model...")
|
||||
original_model = ProgressiveReasoningModel(config)
|
||||
original_model.setup_base_model()
|
||||
|
||||
# Trained model (with adapter)
|
||||
print("Loading trained model...")
|
||||
trained_model = ProgressiveReasoningModel(config)
|
||||
trained_model.setup_base_model()
|
||||
|
||||
# Load the trained adapter if it exists
|
||||
adapter_path = Path(config["experiment"]["output_dir"]) / "adapters" / args.adapter
|
||||
if adapter_path.exists():
|
||||
print(f"Loading trained adapter from: {adapter_path}")
|
||||
try:
|
||||
trained_model.load_for_inference([args.adapter])
|
||||
print("Adapter loaded successfully!")
|
||||
except Exception as e:
|
||||
print(f"Error loading adapter: {e}")
|
||||
print("Will compare with base model instead.")
|
||||
else:
|
||||
print(f"No trained adapter found at: {adapter_path}")
|
||||
print("Available adapters:")
|
||||
adapters_dir = Path(config["experiment"]["output_dir"]) / "adapters"
|
||||
if adapters_dir.exists():
|
||||
for adapter_dir in adapters_dir.iterdir():
|
||||
if adapter_dir.is_dir():
|
||||
print(f" - {adapter_dir.name}")
|
||||
else:
|
||||
print(" No adapters directory found.")
|
||||
print("Both models will show original behavior.")
|
||||
|
||||
print("\nModels loaded! Enter prompts to compare (type 'quit' to exit)")
|
||||
print("Examples:")
|
||||
print(" - What is 25 + 17?")
|
||||
print(" - Explain why the sky is blue")
|
||||
print(" - Solve this step by step: If I have 10 apples and give away 3, how many do I have left?")
|
||||
print()
|
||||
|
||||
while True:
|
||||
try:
|
||||
prompt = input("\nPrompt: ").strip()
|
||||
if prompt.lower() in ['quit', 'exit', 'q']:
|
||||
break
|
||||
|
||||
if not prompt:
|
||||
continue
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("ORIGINAL MODEL (No fine-tuning)")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
original_response = generate_response(
|
||||
original_model.model,
|
||||
original_model.tokenizer,
|
||||
prompt,
|
||||
args.max_length
|
||||
)
|
||||
print(original_response)
|
||||
except Exception as e:
|
||||
print(f"Error generating original response: {e}")
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"TRAINED MODEL (With {args.adapter} adapter)")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Add CoT prompt for trained model
|
||||
cot_prompt = f"{prompt}\n\nPlease think step by step using <think> tags."
|
||||
trained_response = generate_response(
|
||||
trained_model.model,
|
||||
trained_model.tokenizer,
|
||||
cot_prompt,
|
||||
args.max_length
|
||||
)
|
||||
print(trained_response)
|
||||
except Exception as e:
|
||||
print(f"Error generating trained response: {e}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nExiting...")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
continue
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
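Example invocation, with the defaults defined above spelled out:

    python scripts/simple_compare.py --config config/training_config_gemma2_small.yaml --adapter basic_cot --max-length 512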
131
scripts/train_progressive.py
Executable file
@ -0,0 +1,131 @@
#!/usr/bin/env python3
|
||||
"""
|
||||
Main training script for progressive reasoning model
|
||||
"""
|
||||
|
||||
import sys
|
||||
import yaml
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
# Add src to path
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
from src.progressive_model import ProgressiveReasoningModel
|
||||
from src.training import ProgressiveTrainer
|
||||
from src.data_utils import prepare_sample_datasets
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Progressive LLM Training for 松尾研LLMコンペ2025",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Use default config
|
||||
python scripts/train_progressive.py
|
||||
|
||||
# Use specific config file
|
||||
python scripts/train_progressive.py --config config/training_config_large.yaml
|
||||
|
||||
# Use config with custom path
|
||||
python scripts/train_progressive.py --config /path/to/my_config.yaml
|
||||
|
||||
# Prepare sample datasets
|
||||
python scripts/train_progressive.py --prepare-data
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--config", "-c",
|
||||
type=str,
|
||||
default="config/training_config.yaml",
|
||||
help="Path to the training configuration file (default: config/training_config.yaml)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prepare-data",
|
||||
action="store_true",
|
||||
help="Prepare sample datasets before training"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Load config and model but skip training (for testing)"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict:
|
||||
"""Load configuration from file"""
|
||||
config_path = Path(config_path)
|
||||
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Configuration file not found: {config_path}")
|
||||
|
||||
print(f"Loading configuration from: {config_path}")
|
||||
|
||||
with open(config_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
print("Progressive LLM Training for 松尾研LLMコンペ2025")
|
||||
print("=" * 50)
|
||||
|
||||
# Load configuration
|
||||
try:
|
||||
config = load_config(args.config)
|
||||
except FileNotFoundError as e:
|
||||
print(f"Error: {e}")
|
||||
print("Available config files:")
|
||||
config_dir = Path("config")
|
||||
if config_dir.exists():
|
||||
for config_file in config_dir.glob("*.yaml"):
|
||||
print(f" {config_file}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"Error loading config: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Print configuration info
|
||||
print(f"Experiment: {config['experiment']['name']}")
|
||||
print(f"Base model: {config['experiment']['base_model']}")
|
||||
print(f"Output directory: {config['experiment']['output_dir']}")
|
||||
print(f"Stages: {len(config['progressive_stages'])}")
|
||||
|
||||
# Prepare sample datasets if requested
|
||||
if args.prepare_data:
|
||||
print("\nPreparing sample datasets...")
|
||||
prepare_sample_datasets()
|
||||
print("Sample datasets prepared.")
|
||||
|
||||
# Initialize model wrapper
|
||||
print("\nInitializing model...")
|
||||
model_wrapper = ProgressiveReasoningModel(config)
|
||||
model_wrapper.setup_base_model()
|
||||
|
||||
if args.dry_run:
|
||||
print("\nDry run completed. Model loaded successfully.")
|
||||
return
|
||||
|
||||
# Initialize trainer
|
||||
print("\nInitializing trainer...")
|
||||
trainer = ProgressiveTrainer(model_wrapper, config)
|
||||
|
||||
# Run progressive training
|
||||
print("\nStarting progressive training...")
|
||||
trainer.run_progressive_training()
|
||||
|
||||
print("\nTraining completed successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
0
src/__init__.py
Normal file
88
src/data_utils.py
Normal file
@ -0,0 +1,88 @@
import json
|
||||
import jsonlines
|
||||
from typing import List, Dict
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
|
||||
def create_think_tag_example(question: str, reasoning: str, answer: str) -> Dict:
|
||||
"""Create training example with think tags"""
|
||||
output = f"<think>\n{reasoning}\n</think>\n\n{answer}"
|
||||
|
||||
return {
|
||||
"input": question,
|
||||
"output": output
|
||||
}
|
||||
|
||||
|
||||
def prepare_basic_cot_data(output_dir: str, num_examples: int = 1000):
|
||||
"""Create basic Chain-of-Thought examples"""
|
||||
output_path = Path(output_dir) / "basic_cot"
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
examples = []
|
||||
|
||||
# Simple arithmetic examples
|
||||
for i in range(num_examples // 2):
|
||||
a = random.randint(10, 100)
|
||||
b = random.randint(10, 100)
|
||||
question = f"What is {a} + {b}?"
|
||||
reasoning = f"To find {a} + {b}, I need to add these two numbers together.\n{a} + {b} = {a + b}"
|
||||
answer = f"The answer is {a + b}."
|
||||
|
||||
examples.append(create_think_tag_example(question, reasoning, answer))
|
||||
|
||||
# Simple word problems
|
||||
templates = [
|
||||
{
|
||||
"question": "If I have {a} apples and buy {b} more, how many apples do I have?",
|
||||
"reasoning": "Starting with {a} apples, then adding {b} more apples.\nTotal: {a} + {b} = {result}",
|
||||
"answer": "I have {result} apples."
|
||||
},
|
||||
{
|
||||
"question": "A store has {a} items. If {b} are sold, how many remain?",
|
||||
"reasoning": "Starting amount: {a} items\nSold: {b} items\nRemaining: {a} - {b} = {result}",
|
||||
"answer": "There are {result} items remaining."
|
||||
}
|
||||
]
|
||||
|
||||
for i in range(num_examples // 2):
|
||||
template = random.choice(templates)
|
||||
a = random.randint(20, 200)
|
||||
b = random.randint(10, min(50, a))
|
||||
|
||||
if "+" in template["reasoning"]:
|
||||
result = a + b
|
||||
else:
|
||||
result = a - b
|
||||
|
||||
question = template["question"].format(a=a, b=b)
|
||||
reasoning = template["reasoning"].format(a=a, b=b, result=result)
|
||||
answer = template["answer"].format(result=result)
|
||||
|
||||
examples.append(create_think_tag_example(question, reasoning, answer))
|
||||
|
||||
# Save to jsonl
|
||||
output_file = output_path / "train.jsonl"
|
||||
with jsonlines.open(output_file, "w") as writer:
|
||||
writer.write_all(examples)
|
||||
|
||||
print(f"Created {len(examples)} basic CoT examples at: {output_file}")
|
||||
|
||||
|
||||
def prepare_sample_datasets(base_dir: str = "./data"):
|
||||
"""Prepare sample datasets for all stages"""
|
||||
base_path = Path(base_dir)
|
||||
|
||||
# Basic CoT
|
||||
prepare_basic_cot_data(base_path)
|
||||
|
||||
# Math reasoning (placeholder)
|
||||
math_path = base_path / "math_reasoning"
|
||||
math_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Complex reasoning (placeholder)
|
||||
complex_path = base_path / "complex_reasoning"
|
||||
complex_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Sample datasets prepared in: {base_path}")
366
src/progressive_model.py
Normal file
@ -0,0 +1,366 @@
import torch
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
TrainingArguments
|
||||
)
|
||||
from peft import (
|
||||
LoraConfig,
|
||||
PeftModel,
|
||||
TaskType,
|
||||
get_peft_model,
|
||||
prepare_model_for_kbit_training
|
||||
)
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class ProgressiveReasoningModel:
|
||||
"""Progressive training approach for reasoning models"""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
self.config = config
|
||||
self.base_model_name = config["experiment"]["base_model"]
|
||||
self.output_dir = Path(config["experiment"]["output_dir"])
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.adapters = {}
|
||||
self.training_history = []
|
||||
|
||||
def setup_base_model(self):
|
||||
"""Initialize base model with quantization"""
|
||||
print(f"Loading base model: {self.base_model_name}")
|
||||
|
||||
# Check if quantization is enabled
|
||||
if self.config["model"].get("load_in_4bit", False):
|
||||
# BitsAndBytes config for 4-bit quantization
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=getattr(torch, self.config["model"]["bnb_4bit_compute_dtype"]),
|
||||
bnb_4bit_use_double_quant=self.config["model"]["bnb_4bit_use_double_quant"],
|
||||
bnb_4bit_quant_type=self.config["model"].get("bnb_4bit_quant_type", "nf4")
|
||||
)
|
||||
quantization_config = bnb_config
|
||||
else:
|
||||
quantization_config = None
|
||||
|
||||
# Model loading arguments
|
||||
model_kwargs = {
|
||||
"device_map": self.config["model"]["device_map"],
|
||||
"trust_remote_code": True,
|
||||
"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
||||
}
|
||||
|
||||
# Add authentication token if provided
|
||||
if "hf_token" in self.config["model"] and self.config["model"]["hf_token"]:
|
||||
model_kwargs["token"] = self.config["model"]["hf_token"]
|
||||
|
||||
# Add max_memory configuration for CPU offloading
|
||||
if "max_memory" in self.config["model"]:
|
||||
model_kwargs["max_memory"] = self.config["model"]["max_memory"]
|
||||
print(f"Using max_memory configuration: {model_kwargs['max_memory']}")
|
||||
|
||||
# Add offload folder if specified
|
||||
if "offload_folder" in self.config["model"]:
|
||||
model_kwargs["offload_folder"] = self.config["model"]["offload_folder"]
|
||||
model_kwargs["offload_state_dict"] = True
|
||||
print(f"Using offload folder: {model_kwargs['offload_folder']}")
|
||||
|
||||
# Note: llm_int8_enable_fp32_cpu_offload is not supported for all models
|
||||
# Only add it if we're not using Gemma models
|
||||
if (quantization_config and
|
||||
self.config["model"].get("llm_int8_enable_fp32_cpu_offload", False) and
|
||||
"gemma" not in self.base_model_name.lower()):
|
||||
model_kwargs["llm_int8_enable_fp32_cpu_offload"] = True
|
||||
print("Enabled FP32 CPU offload for quantized model")
|
||||
|
||||
# Add quantization config if enabled
|
||||
if quantization_config:
|
||||
model_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
# Add attention implementation
|
||||
if self.config["model"].get("use_flash_attention_2", False):
|
||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
elif self.config["model"].get("use_eager_attention", False):
|
||||
model_kwargs["attn_implementation"] = "eager"
|
||||
|
||||
# Load model
|
||||
print("Loading model with the following kwargs:")
|
||||
for k, v in model_kwargs.items():
|
||||
if k != "quantization_config":
|
||||
print(f" {k}: {v}")
|
||||
else:
|
||||
print(f" {k}: <BitsAndBytesConfig>")
|
||||
|
||||
try:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.base_model_name,
|
||||
**model_kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error loading model: {e}")
|
||||
# Try without some problematic kwargs
|
||||
if "offload_folder" in model_kwargs:
|
||||
print("Retrying without offload_folder...")
|
||||
del model_kwargs["offload_folder"]
|
||||
del model_kwargs["offload_state_dict"]
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.base_model_name,
|
||||
**model_kwargs
|
||||
)
|
||||
|
||||
# Prepare for k-bit training if using quantization
|
||||
if quantization_config:
|
||||
self.model = prepare_model_for_kbit_training(self.model)
|
||||
|
||||
# Disable gradient checkpointing for now to avoid conflicts
|
||||
# Enable gradient checkpointing if requested (but disable use_cache)
|
||||
# if self.config["model"].get("gradient_checkpointing", False):
|
||||
# self.model.gradient_checkpointing_enable()
|
||||
# self.model.config.use_cache = False
|
||||
# print("Gradient checkpointing enabled, use_cache disabled")
|
||||
|
||||
# Explicitly disable use_cache to avoid conflicts
|
||||
if hasattr(self.model, 'config'):
|
||||
self.model.config.use_cache = False
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer_kwargs = {"trust_remote_code": True}
|
||||
if "hf_token" in self.config["model"] and self.config["model"]["hf_token"]:
|
||||
tokenizer_kwargs["token"] = self.config["model"]["hf_token"]
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.base_model_name,
|
||||
**tokenizer_kwargs
|
||||
)
|
||||
|
||||
# Set padding token and other special tokens for Gemma
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
# For Gemma models, ensure special tokens are set
|
||||
if "gemma" in self.base_model_name.lower():
|
||||
print("Configuring Gemma-specific tokenizer settings")
|
||||
# Add special tokens if they don't exist
|
||||
special_tokens = {
|
||||
"bos_token": "<bos>",
|
||||
"eos_token": "<eos>",
|
||||
"pad_token": "<pad>",
|
||||
}
|
||||
|
||||
# Only add tokens that don't already exist
|
||||
tokens_to_add = {}
|
||||
for token_name, token_value in special_tokens.items():
|
||||
if getattr(self.tokenizer, token_name, None) is None:
|
||||
tokens_to_add[token_name] = token_value
|
||||
|
||||
if tokens_to_add:
|
||||
num_added = self.tokenizer.add_special_tokens(tokens_to_add)
|
||||
print(f"Added special tokens: {tokens_to_add}")
|
||||
if num_added > 0:
|
||||
# Resize model embeddings to accommodate new tokens
|
||||
self.model.resize_token_embeddings(len(self.tokenizer))
|
||||
print(f"Resized model embeddings to {len(self.tokenizer)} tokens")
|
||||
|
||||
# Set appropriate model_max_length for Gemma
|
||||
if hasattr(self.tokenizer, 'model_max_length') and self.tokenizer.model_max_length > 8192:
|
||||
self.tokenizer.model_max_length = 8192
|
||||
print(f"Set tokenizer model_max_length to {self.tokenizer.model_max_length}")
|
||||
|
||||
# Debug: print model structure for target module identification
|
||||
print("Model structure:")
|
||||
for name, module in self.model.named_modules():
|
||||
if any(target in name for target in ['attn', 'proj', 'mlp', 'gate', 'up', 'down']):
|
||||
print(f" {name}: {type(module).__name__}")
|
||||
|
||||
print("Base model loaded successfully")
|
||||
|
||||
def get_target_modules(self, suggested_modules):
|
||||
"""Auto-detect valid target modules for the model"""
|
||||
valid_modules = []
|
||||
all_modules = [name for name, _ in self.model.named_modules()]
|
||||
|
||||
# Check each suggested module
|
||||
for module_name in suggested_modules:
|
||||
# Find modules that contain this name
|
||||
matching_modules = [name for name in all_modules if module_name in name]
|
||||
if matching_modules:
|
||||
valid_modules.append(module_name)
|
||||
print(f" Found target module: {module_name} (matches: {len(matching_modules)} modules)")
|
||||
else:
|
||||
print(f" Warning: target module '{module_name}' not found in model")
|
||||
|
||||
# If no valid modules found, try common alternatives
|
||||
if not valid_modules:
|
||||
print(" No suggested modules found, trying common alternatives...")
|
||||
common_alternatives = [
|
||||
"q_proj", "k_proj", "v_proj", "o_proj", # Common attention
|
||||
"gate_proj", "up_proj", "down_proj", # Common MLP
|
||||
"self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj", # Full path
|
||||
"mlp.gate_proj", "mlp.up_proj", "mlp.down_proj", # Full MLP path
|
||||
]
|
||||
|
||||
for module_name in common_alternatives:
|
||||
matching_modules = [name for name in all_modules if module_name in name]
|
||||
if matching_modules:
|
||||
valid_modules.append(module_name)
|
||||
print(f" Found alternative target module: {module_name}")
|
||||
if len(valid_modules) >= 2: # At least 2 modules
|
||||
break
|
||||
|
||||
if not valid_modules:
|
||||
print(" ERROR: No valid target modules found!")
|
||||
print(" Available modules containing 'proj' or 'attn':")
|
||||
for name in all_modules:
|
||||
if any(keyword in name.lower() for keyword in ['proj', 'attn', 'mlp']):
|
||||
print(f" {name}")
|
||||
# Fallback to a basic module that should exist
|
||||
valid_modules = ["embed_tokens"]
|
||||
|
||||
return valid_modules
|
||||
|
||||
def create_adapter(self, stage_config: dict) -> LoraConfig:
|
||||
"""Create LoRA adapter configuration"""
|
||||
adapter_config = stage_config["adapter_config"]
|
||||
|
||||
# Get initialization method from config, default to True for identity init
|
||||
init_method = adapter_config.get("init_lora_weights", True)
|
||||
|
||||
# Auto-detect valid target modules
|
||||
suggested_modules = adapter_config["target_modules"]
|
||||
valid_modules = self.get_target_modules(suggested_modules)
|
||||
|
||||
print(f"Using target modules: {valid_modules}")
|
||||
|
||||
return LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
r=adapter_config["r"],
|
||||
lora_alpha=adapter_config["lora_alpha"],
|
||||
lora_dropout=adapter_config["lora_dropout"],
|
||||
target_modules=valid_modules,
|
||||
bias="none",
|
||||
init_lora_weights=init_method # Initialize LoRA weights (True = identity, "gaussian" = random)
|
||||
)
|
||||
|
||||
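# Example of the stage entry that create_adapter reads from the YAML config
# (a sketch: the adapter_config keys follow the lookups above; the stage name,
#  nesting and concrete values are illustrative assumptions):
#
#   progressive_stages:
#     - name: basic_cot
#       adapter_config:
#         r: 16
#         lora_alpha: 32
#         lora_dropout: 0.05
#         target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
#         init_lora_weights: true   # True = identity init, "gaussian" = random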
def add_progressive_adapter(self, stage_name: str, stage_config: dict):
|
||||
"""Add a new adapter for progressive training"""
|
||||
print(f"\nAdding adapter for stage: {stage_name}")
|
||||
|
||||
# Check if we should inherit from previous adapter
|
||||
if "inherit_from" in stage_config and stage_config["inherit_from"] in self.adapters:
|
||||
print(f"Inheriting from: {stage_config['inherit_from']}")
|
||||
# Load previous adapter as base
|
||||
prev_adapter_path = self.adapters[stage_config["inherit_from"]]
|
||||
self.model = PeftModel.from_pretrained(
|
||||
self.model,
|
||||
prev_adapter_path,
|
||||
is_trainable=True
|
||||
)
|
||||
# Merge and unload to incorporate previous learning
|
||||
self.model = self.model.merge_and_unload()
|
||||
|
||||
# Create new adapter config
|
||||
lora_config = self.create_adapter(stage_config)
|
||||
|
||||
# Add adapter to model
|
||||
self.model = get_peft_model(self.model, lora_config)
|
||||
|
||||
# Ensure model is in training mode
|
||||
self.model.train()
|
||||
|
||||
# Print trainable parameters
|
||||
self.model.print_trainable_parameters()
|
||||
|
||||
# Debug: check if any parameters require gradients
|
||||
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||
total_params = sum(p.numel() for p in self.model.parameters())
|
||||
print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")
|
||||
|
||||
# List parameters that require gradients
|
||||
grad_params = [name for name, param in self.model.named_parameters() if param.requires_grad]
|
||||
print(f"Parameters requiring gradients: {len(grad_params)} parameters")
|
||||
if len(grad_params) > 0:
|
||||
print(f"First few: {grad_params[:5]}")
|
||||
else:
|
||||
print("WARNING: No parameters require gradients!")
|
||||
|
||||
# Save adapter path
|
||||
adapter_path = self.output_dir / "adapters" / stage_name
|
||||
adapter_path.mkdir(parents=True, exist_ok=True)
|
||||
self.adapters[stage_name] = str(adapter_path)
|
||||
|
||||
def save_adapter(self, stage_name: str):
|
||||
"""Save current adapter"""
|
||||
if stage_name in self.adapters:
|
||||
print(f"Saving adapter: {stage_name}")
|
||||
self.model.save_pretrained(self.adapters[stage_name])
|
||||
# Also save tokenizer for convenience
|
||||
self.tokenizer.save_pretrained(self.adapters[stage_name])
|
||||
|
||||
def load_for_inference(self, adapter_names: List[str], weights: Optional[Dict[str, float]] = None):
|
||||
"""Load model with specific adapters for inference"""
|
||||
if len(adapter_names) == 1:
|
||||
# Single adapter
|
||||
adapter_name = adapter_names[0]
|
||||
|
||||
# Check if adapter path is in memory
|
||||
if adapter_name in self.adapters:
|
||||
adapter_path = self.adapters[adapter_name]
|
||||
else:
|
||||
# Try to find adapter in output directory
|
||||
adapter_path = self.output_dir / "adapters" / adapter_name
|
||||
if not adapter_path.exists():
|
||||
raise ValueError(f"Adapter {adapter_name} not found at {adapter_path}")
|
||||
adapter_path = str(adapter_path)
|
||||
|
||||
print(f"Loading adapter from: {adapter_path}")
|
||||
self.model = PeftModel.from_pretrained(
|
||||
self.model,
|
||||
adapter_path
|
||||
)
|
||||
else:
|
||||
# Multiple adapters - load and combine
|
||||
# This is a simplified version - real implementation would need adapter composition
|
||||
print("Multi-adapter inference not fully implemented in this bootstrap")
|
||||
# For now, just load the last adapter
|
||||
adapter_name = adapter_names[-1]
|
||||
if adapter_name in self.adapters:
|
||||
adapter_path = self.adapters[adapter_name]
|
||||
else:
|
||||
adapter_path = str(self.output_dir / "adapters" / adapter_name)
|
||||
self.model = PeftModel.from_pretrained(
|
||||
self.model,
|
||||
adapter_path
|
||||
)
|
||||
|
||||
def generate_with_reasoning(self, prompt: str, max_length: int = 2048) -> str:
|
||||
"""Generate response with reasoning"""
|
||||
# Format prompt with think tags expectation
|
||||
formatted_prompt = f"{prompt}\n\nPlease think step by step using <think> tags before providing your answer."
|
||||
|
||||
# Tokenize
|
||||
inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
max_length=max_length,
|
||||
temperature=0.7,
|
||||
do_sample=True,
|
||||
top_p=0.95,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
eos_token_id=self.tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
# Decode
|
||||
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
|
||||
# Extract response after prompt
|
||||
response = response[len(formatted_prompt):].strip()
|
||||
|
||||
return response
|
||||
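
# Illustrative usage sketch: how the inference helpers above fit together.
# It assumes a ProgressiveReasoningModel instance named `wrapper` has already
# been constructed with its base model loaded; the adapter name "stage1" and
# the prompt are placeholders.
#
# wrapper.load_for_inference(["stage1"])
# print(wrapper.generate_with_reasoning("What is 17 * 23?"))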
450
src/training.py
Normal file
@@ -0,0 +1,450 @@
from transformers import TrainingArguments
from trl import SFTTrainer
from datasets import load_dataset, Dataset
import torch
from typing import Dict, List
import json
import jsonlines
from pathlib import Path


class ProgressiveTrainer:
    """Handle progressive training stages"""

    def __init__(self, model_wrapper, config: dict):
        self.model_wrapper = model_wrapper
        self.config = config
        self.training_history = []

    def load_dataset(self, dataset_path: str, stage_config: dict = None) -> Dataset:
        """Load dataset from jsonl files or HuggingFace datasets"""
        print(f"Loading dataset from path: {dataset_path}")

        # Check if it's a HuggingFace dataset (contains '/')
        if '/' in dataset_path and not Path(dataset_path).exists():
            print(f"Loading HuggingFace dataset: {dataset_path}")
            return self.load_huggingface_dataset(dataset_path, stage_config)

        # Load local dataset
        data = []
        print(f"Current working directory: {Path.cwd()}")

        # Support both single file and directory
        path = Path(dataset_path)
        print(f"Path exists: {path.exists()}")
        print(f"Is file: {path.is_file()}")
        print(f"Is directory: {path.is_dir()}")

        if path.is_file():
            files = [path]
        else:
            files = list(path.glob("*.jsonl"))

        print(f"Found {len(files)} files to load")
        for f in files:
            print(f"  - {f}")

        for file_path in files:
            print(f"Loading file: {file_path}")
            try:
                with jsonlines.open(file_path) as reader:
                    count = 0
                    for item in reader:
                        # Format for chat template
                        formatted = {
                            "messages": [
                                {"role": "user", "content": item["input"]},
                                {"role": "assistant", "content": item["output"]}
                            ]
                        }
                        data.append(formatted)
                        count += 1
                    print(f"  Loaded {count} examples from {file_path}")
            except Exception as e:
                print(f"  Error loading file {file_path}: {e}")

        print(f"Total examples loaded: {len(data)}")
        return Dataset.from_list(data)
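
    # Illustrative sketch: load_dataset() above expects each line of a local
    # .jsonl file to carry "input" and "output" fields, for example
    # (the content is a made-up placeholder):
    #
    #   {"input": "What is 2 + 2?", "output": "<think>2 + 2 = 4</think>\n\n4"}
    #
    # Each record is re-wrapped into the chat-style "messages" structure used
    # by the rest of the trainer.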

    def load_huggingface_dataset(self, dataset_name: str, stage_config: dict) -> Dataset:
        """Load dataset from HuggingFace"""
        try:
            dataset_config = stage_config.get("dataset_config", {}) if stage_config else {}

            # Default settings
            split = dataset_config.get("split", "train")
            max_samples = dataset_config.get("max_samples", None)
            streaming = dataset_config.get("streaming", False)

            print(f"Loading HuggingFace dataset: {dataset_name}")
            print(f"  Split: {split}")
            print(f"  Max samples: {max_samples}")
            print(f"  Streaming: {streaming}")

            # Load dataset
            if streaming:
                dataset = load_dataset(dataset_name, split=split, streaming=True)
                if max_samples:
                    dataset = dataset.take(max_samples)
                # Convert streaming dataset to regular dataset
                data = []
                count = 0
                for item in dataset:
                    data.append(item)
                    count += 1
                    if count % 1000 == 0:
                        print(f"  Loaded {count} examples...")
                    if max_samples and count >= max_samples:
                        break
                dataset = Dataset.from_list(data)
            else:
                dataset = load_dataset(dataset_name, split=split)
                if max_samples:
                    dataset = dataset.select(range(min(max_samples, len(dataset))))

            print(f"  Loaded dataset with {len(dataset)} examples")
            print(f"  Dataset columns: {dataset.column_names}")
            if len(dataset) > 0:
                print(f"  First example: {dataset[0]}")

            # Convert to our expected format based on dataset name
            if "math" in dataset_name.lower():
                return self.convert_math_dataset(dataset)
            elif "mixture-of-thoughts" in dataset_name.lower():
                return self.convert_mixture_of_thoughts_dataset(dataset)
            else:
                return self.convert_generic_dataset(dataset)

        except Exception as e:
            print(f"Error loading HuggingFace dataset {dataset_name}: {e}")
            print("Falling back to empty dataset")
            return Dataset.from_list([])
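
    # Illustrative sketch: load_huggingface_dataset() only reads these
    # optional keys from a stage's "dataset_config" block; the values shown
    # are assumptions for illustration.
    #
    # example_dataset_config = {
    #     "split": "train",       # which split to load
    #     "max_samples": 10000,   # cap on examples (omit or None for all)
    #     "streaming": False,     # stream and materialize instead of a full download
    # }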

    def convert_math_dataset(self, dataset: Dataset) -> Dataset:
        """Convert OpenR1-Math-220k format to our training format"""
        def format_math_example(example):
            # OpenR1-Math-220k format has different column names
            # Try to find the right columns
            input_text = None
            output_text = None

            # Common column names in math datasets
            if "question" in example:
                input_text = example["question"]
            elif "problem" in example:
                input_text = example["problem"]
            elif "input" in example:
                input_text = example["input"]
            elif "query" in example:
                input_text = example["query"]

            if "answer" in example:
                output_text = example["answer"]
            elif "solution" in example:
                output_text = example["solution"]
            elif "output" in example:
                output_text = example["output"]
            elif "response" in example:
                output_text = example["response"]

            # If we can't find the right columns, use the raw example
            if input_text is None or output_text is None:
                print(f"Warning: Could not parse example columns: {list(example.keys())}")
                # Try to use the first two string fields
                string_fields = [k for k, v in example.items() if isinstance(v, str) and len(v) > 10]
                if len(string_fields) >= 2:
                    input_text = example[string_fields[0]]
                    output_text = example[string_fields[1]]
                else:
                    # Mark this example as invalid; Dataset.map() must return a dict,
                    # so an empty "messages" list is used as the skip marker
                    return {"messages": []}

            # Format with think tags for math reasoning
            formatted_output = f"<think>\nLet me solve this step by step.\n\n{output_text}\n</think>\n\n{output_text}"

            return {
                "messages": [
                    {"role": "user", "content": input_text},
                    {"role": "assistant", "content": formatted_output}
                ]
            }

        # Convert and drop examples that could not be parsed
        converted = dataset.map(format_math_example, desc="Converting math dataset")
        converted = converted.filter(lambda x: len(x["messages"]) > 0, desc="Filtering valid examples")

        print(f"Converted {len(converted)} math examples")
        if len(converted) > 0:
            print(f"First converted example: {converted[0]}")

        return converted
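
    # Illustrative sketch: for a source record such as
    # {"problem": "...", "solution": "x = 4"}, convert_math_dataset() above
    # produces an assistant message of this shape (note that the solution text
    # currently appears both inside and after the <think> block):
    #
    #   <think>
    #   Let me solve this step by step.
    #
    #   x = 4
    #   </think>
    #
    #   x = 4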

    def convert_mixture_of_thoughts_dataset(self, dataset: Dataset) -> Dataset:
        """Convert Mixture-of-Thoughts format to our training format"""
        def format_mot_example(example):
            # Mixture-of-Thoughts typically has complex reasoning patterns
            # Check for common column names in the dataset
            input_text = None
            output_text = None

            # Try to identify input/output columns
            if "prompt" in example:
                input_text = example["prompt"]
            elif "question" in example:
                input_text = example["question"]
            elif "input" in example:
                input_text = example["input"]
            elif "instruction" in example:
                input_text = example["instruction"]

            if "response" in example:
                output_text = example["response"]
            elif "output" in example:
                output_text = example["output"]
            elif "completion" in example:
                output_text = example["completion"]
            elif "answer" in example:
                output_text = example["answer"]

            # If columns not found, look for thinking patterns
            if input_text is None or output_text is None:
                # Try to find columns with substantial text
                for key, value in example.items():
                    if isinstance(value, str) and len(value) > 20:
                        if input_text is None and any(q in key.lower() for q in ["prompt", "question", "input"]):
                            input_text = value
                        elif output_text is None and any(a in key.lower() for a in ["response", "answer", "output"]):
                            output_text = value

            if input_text is None or output_text is None:
                print(f"Warning: Could not parse MoT example columns: {list(example.keys())}")
                # Mark this example as invalid (Dataset.map() must return a dict)
                return {"messages": []}

            # Check if output already contains thinking tags
            if "<think>" in output_text or "思考" in output_text:
                # Already formatted with thinking
                formatted_output = output_text
            else:
                # Add thinking structure for complex reasoning
                formatted_output = f"<think>\nLet me break this down step by step.\n\n{output_text}\n</think>\n\nBased on my analysis, {output_text}"

            return {
                "messages": [
                    {"role": "user", "content": input_text},
                    {"role": "assistant", "content": formatted_output}
                ]
            }

        # Convert and drop examples that could not be parsed
        converted = dataset.map(format_mot_example, desc="Converting Mixture-of-Thoughts dataset")
        converted = converted.filter(lambda x: len(x["messages"]) > 0, desc="Filtering valid examples")

        print(f"Converted {len(converted)} Mixture-of-Thoughts examples")
        if len(converted) > 0:
            print(f"First converted example: {converted[0]}")

        return converted

    def convert_generic_dataset(self, dataset: Dataset) -> Dataset:
        """Convert generic dataset format to our training format"""
        def format_generic_example(example):
            # Generic conversion for unknown dataset formats
            input_text = None
            output_text = None

            # Look for any text columns
            text_columns = [(k, v) for k, v in example.items() if isinstance(v, str) and len(v) > 10]

            if len(text_columns) >= 2:
                # Use first two substantial text columns
                input_text = text_columns[0][1]
                output_text = text_columns[1][1]
            else:
                # Zero or one usable text column - mark the example as invalid
                return {"messages": []}

            return {
                "messages": [
                    {"role": "user", "content": input_text},
                    {"role": "assistant", "content": output_text}
                ]
            }

        converted = dataset.map(format_generic_example, desc="Converting generic dataset")
        converted = converted.filter(lambda x: len(x["messages"]) > 0, desc="Filtering valid examples")

        print(f"Converted {len(converted)} generic examples")
        return converted

    def format_dataset(self, dataset: Dataset) -> Dataset:
        """Format dataset for training"""
        print(f"Dataset before formatting: {len(dataset)} examples")
        print(f"First example: {dataset[0] if len(dataset) > 0 else 'No data'}")

        # Check if tokenizer has chat template
        has_chat_template = (
            hasattr(self.model_wrapper.tokenizer, 'chat_template') and
            self.model_wrapper.tokenizer.chat_template is not None
        )

        if not has_chat_template:
            print("No chat template found, setting default Gemma chat template")
            # Set a simple chat template for Gemma
            self.model_wrapper.tokenizer.chat_template = "{% for message in messages %}<start_of_turn>{{ message['role'] }}\n{{ message['content'] }}<end_of_turn>\n{% endfor %}"

        def format_chat(example):
            # Try to use chat template if available
            if has_chat_template or self.model_wrapper.tokenizer.chat_template:
                try:
                    text = self.model_wrapper.tokenizer.apply_chat_template(
                        example["messages"],
                        tokenize=False,
                        add_generation_prompt=False
                    )
                    return {"text": text}
                except Exception as e:
                    print(f"Chat template failed: {e}, using fallback")

            # Fallback: create simple formatted text
            if "messages" in example:
                user_msg = example["messages"][0]["content"]
                assistant_msg = example["messages"][1]["content"]
                return {"text": f"<start_of_turn>user\n{user_msg}<end_of_turn>\n<start_of_turn>model\n{assistant_msg}<end_of_turn>\n"}
            elif "input" in example and "output" in example:
                return {"text": f"<start_of_turn>user\n{example['input']}<end_of_turn>\n<start_of_turn>model\n{example['output']}<end_of_turn>\n"}
            else:
                return {"text": str(example)}

        # Format dataset
        formatted = dataset.map(format_chat, batched=False, desc="Formatting dataset")
        print(f"Dataset after formatting: {len(formatted)} examples")
        if len(formatted) > 0:
            print(f"Columns: {formatted.column_names}")
            print(f"First formatted example: {formatted[0]}")

        # Keep only the 'text' column for SFTTrainer
        if 'text' in formatted.column_names:
            columns_to_remove = [col for col in formatted.column_names if col != 'text']
            if columns_to_remove:
                formatted = formatted.remove_columns(columns_to_remove)

        return formatted
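
    # Illustrative sketch: with the fallback Gemma template above, a single
    # formatted training example becomes text of this shape (the content is a
    # made-up placeholder):
    #
    #   <start_of_turn>user
    #   What is 2 + 2?<end_of_turn>
    #   <start_of_turn>model
    #   <think>2 + 2 = 4</think>
    #
    #   4<end_of_turn>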

    def filter_by_length(self, dataset: Dataset, max_length: int) -> Dataset:
        """Filter dataset by sequence length"""
        def is_valid_length(example):
            # Tokenize and check length
            tokens = self.model_wrapper.tokenizer(
                example["text"],
                truncation=False,
                return_length=True
            )
            return len(tokens["input_ids"]) <= max_length

        filtered = dataset.filter(is_valid_length, desc="Filtering by length")
        print(f"Filtered dataset: {len(filtered)} examples (max_length={max_length})")
        return filtered

    def train_stage(self, stage_name: str, stage_config: dict):
        """Train a single stage"""
        print(f"\n{'='*50}")
        print(f"Training stage: {stage_name}")
        print(f"Description: {stage_config['description']}")
        print(f"{'='*50}\n")

        # Add adapter
        self.model_wrapper.add_progressive_adapter(stage_name, stage_config)

        # Load and format dataset
        dataset = self.load_dataset(stage_config["dataset_path"], stage_config)
        dataset = self.format_dataset(dataset)

        # Filter by sequence length if specified
        if "max_length" in stage_config["training"]:
            dataset = self.filter_by_length(dataset, stage_config["training"]["max_length"])

        print(f"Final dataset size: {len(dataset)} examples")

        # Training arguments - with CPU offload optimizations
        training_args = TrainingArguments(
            output_dir=f"./outputs/checkpoints/{stage_name}",
            num_train_epochs=stage_config["training"]["num_epochs"],
            per_device_train_batch_size=stage_config["training"]["per_device_batch_size"],
            gradient_accumulation_steps=stage_config["training"]["gradient_accumulation_steps"],
            learning_rate=float(stage_config["training"]["learning_rate"]),  # Ensure it's a float
            warmup_steps=stage_config["training"]["warmup_steps"],
            logging_steps=stage_config["training"].get("logging_steps", 10),
            save_strategy="epoch",
            eval_strategy="no",
            bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
            gradient_checkpointing=self.config["model"].get("gradient_checkpointing", False),
            max_grad_norm=stage_config["training"].get("max_grad_norm", 1.0),
            report_to="wandb" if self.config["experiment"]["use_wandb"] else "none",
            run_name=f"{self.config['experiment']['name']}_{stage_name}",
            dataloader_pin_memory=False,  # Reduce memory usage
            remove_unused_columns=False,  # Keep all columns
            optim=stage_config["training"].get("optim", "adamw_torch"),  # Support 8-bit optimizers
            dataloader_num_workers=stage_config["training"].get("dataloader_num_workers", 2),
        )

        # Print dataset info for debugging
        print(f"Dataset columns: {dataset.column_names}")
        print(f"Dataset first example: {dataset[0]}")

        # Ensure model is in training mode before creating trainer
        self.model_wrapper.model.train()

        # Final check of trainable parameters
        trainable_params = sum(p.numel() for p in self.model_wrapper.model.parameters() if p.requires_grad)
        print(f"Final check - Trainable parameters: {trainable_params:,}")

        # Create trainer with minimal configuration
        try:
            trainer = SFTTrainer(
                model=self.model_wrapper.model,
                processing_class=self.model_wrapper.tokenizer,
                train_dataset=dataset,
                args=training_args,
                packing=False,  # Disable packing for better gradient flow
            )
        except Exception as e:
            print(f"Error creating SFTTrainer: {e}")
            print("Trying with basic configuration...")
            trainer = SFTTrainer(
                model=self.model_wrapper.model,
                processing_class=self.model_wrapper.tokenizer,
                train_dataset=dataset,
                args=training_args,
            )

        # Train
        trainer.train()

        # Save adapter
        self.model_wrapper.save_adapter(stage_name)

        # Record history
        self.training_history.append({
            "stage": stage_name,
            "config": stage_config,
            "metrics": trainer.state.log_history
        })

        print(f"\nCompleted training stage: {stage_name}")

    def run_progressive_training(self):
        """Run all training stages progressively"""
        stages = self.config["progressive_stages"]

        for stage_config in stages:
            stage_name = stage_config["name"]
            self.train_stage(stage_name, stage_config)

        # Save training history
        history_path = Path(self.config["experiment"]["output_dir"]) / "training_history.json"
        with open(history_path, "w") as f:
            json.dump(self.training_history, f, indent=2)

        print(f"\nAll stages completed! Training history saved to: {history_path}")
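
# Illustrative sketch: minimal shape of the config dict ProgressiveTrainer
# reads. Every key below is accessed somewhere in this file, but the concrete
# values (and the stage/dataset names) are assumptions for illustration.
#
# config = {
#     "experiment": {"name": "progressive-llm", "output_dir": "./outputs", "use_wandb": False},
#     "model": {"gradient_checkpointing": True},
#     "progressive_stages": [
#         {
#             "name": "stage1_math",
#             "description": "Math reasoning warm-up",
#             "dataset_path": "data/stage1",            # local .jsonl dir, or a HuggingFace dataset id
#             "dataset_config": {"split": "train", "max_samples": 1000},
#             "adapter_config": {"r": 16, "lora_alpha": 32, "lora_dropout": 0.05,
#                                "target_modules": ["q_proj", "v_proj"]},
#             "training": {
#                 "num_epochs": 1,
#                 "per_device_batch_size": 1,
#                 "gradient_accumulation_steps": 8,
#                 "learning_rate": 2e-4,
#                 "warmup_steps": 10,
#                 "max_length": 2048,
#             },
#         },
#     ],
# }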
35
test_data_load.py
Normal file
@@ -0,0 +1,35 @@
#!/usr/bin/env python3
"""Test data loading"""

import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent))

from src.training import ProgressiveTrainer
from src.progressive_model import ProgressiveReasoningModel
import yaml

# Load config
with open("config/training_config.yaml") as f:
    config = yaml.safe_load(f)

# Create dummy model wrapper
class DummyModelWrapper:
    def __init__(self):
        self.tokenizer = None

model_wrapper = DummyModelWrapper()

# Create trainer
trainer = ProgressiveTrainer(model_wrapper, config)

# Test data loading
stage_config = config["progressive_stages"][0]
dataset_path = stage_config["dataset_path"]
print(f"Loading dataset from: {dataset_path}")

dataset = trainer.load_dataset(dataset_path)
print(f"Loaded {len(dataset)} examples")

if len(dataset) > 0:
    print(f"First example: {dataset[0]}")