commit 2c30e06f206299aee3b1603393571d6e4ddd61de Author: Soma Nakamura Date: Thu Jul 10 18:09:14 2025 +0900 initial diff --git a/.devenv.flake.nix b/.devenv.flake.nix new file mode 100644 index 0000000..b7b7adc --- /dev/null +++ b/.devenv.flake.nix @@ -0,0 +1,163 @@ +{ + inputs = + let + version = "1.6.1"; +system = "x86_64-linux"; +devenv_root = "/home/centra/dev/pnn/progressive-llm-training"; +devenv_dotfile = ./.devenv; +devenv_dotfile_string = ".devenv"; +container_name = null; +devenv_tmpdir = "/run/user/1000"; +devenv_runtime = "/run/user/1000/devenv-adeda32"; +devenv_istesting = false; +devenv_direnvrc_latest_version = 1; + + in { + git-hooks.url = "github:cachix/git-hooks.nix"; + git-hooks.inputs.nixpkgs.follows = "nixpkgs"; + pre-commit-hooks.follows = "git-hooks"; + nixpkgs.url = "github:cachix/devenv-nixpkgs/rolling"; + devenv.url = "github:cachix/devenv?dir=src/modules"; + } // (if builtins.pathExists (devenv_dotfile + "/flake.json") + then builtins.fromJSON (builtins.readFile (devenv_dotfile + "/flake.json")) + else { }); + + outputs = { nixpkgs, ... }@inputs: + let + version = "1.6.1"; +system = "x86_64-linux"; +devenv_root = "/home/centra/dev/pnn/progressive-llm-training"; +devenv_dotfile = ./.devenv; +devenv_dotfile_string = ".devenv"; +container_name = null; +devenv_tmpdir = "/run/user/1000"; +devenv_runtime = "/run/user/1000/devenv-adeda32"; +devenv_istesting = false; +devenv_direnvrc_latest_version = 1; + + devenv = + if builtins.pathExists (devenv_dotfile + "/devenv.json") + then builtins.fromJSON (builtins.readFile (devenv_dotfile + "/devenv.json")) + else { }; + getOverlays = inputName: inputAttrs: + map + (overlay: + let + input = inputs.${inputName} or (throw "No such input `${inputName}` while trying to configure overlays."); + in + input.overlays.${overlay} or (throw "Input `${inputName}` has no overlay called `${overlay}`. Supported overlays: ${nixpkgs.lib.concatStringsSep ", " (builtins.attrNames input.overlays)}")) + inputAttrs.overlays or [ ]; + overlays = nixpkgs.lib.flatten (nixpkgs.lib.mapAttrsToList getOverlays (devenv.inputs or { })); + pkgs = import nixpkgs { + inherit system; + config = { + allowUnfree = devenv.allowUnfree or false; + allowBroken = devenv.allowBroken or false; + permittedInsecurePackages = devenv.permittedInsecurePackages or [ ]; + }; + inherit overlays; + }; + lib = pkgs.lib; + importModule = path: + if lib.hasPrefix "./" path + then if lib.hasSuffix ".nix" path + then ./. + (builtins.substring 1 255 path) + else ./. + (builtins.substring 1 255 path) + "/devenv.nix" + else if lib.hasPrefix "../" path + then throw "devenv: ../ is not supported for imports" + else + let + paths = lib.splitString "/" path; + name = builtins.head paths; + input = inputs.${name} or (throw "Unknown input ${name}"); + subpath = "/${lib.concatStringsSep "/" (builtins.tail paths)}"; + devenvpath = "${input}" + subpath; + devenvdefaultpath = devenvpath + "/devenv.nix"; + in + if lib.hasSuffix ".nix" devenvpath + then devenvpath + else if builtins.pathExists devenvdefaultpath + then devenvdefaultpath + else throw (devenvdefaultpath + " file does not exist for input ${name}."); + project = pkgs.lib.evalModules { + specialArgs = inputs // { inherit inputs; }; + modules = [ + ({ config, ... 
}: { + _module.args.pkgs = pkgs.appendOverlays (config.overlays or [ ]); + }) + (inputs.devenv.modules + /top-level.nix) + { + devenv.cliVersion = version; + devenv.root = devenv_root; + devenv.dotfile = devenv_root + "/" + devenv_dotfile_string; + } + (pkgs.lib.optionalAttrs (inputs.devenv.isTmpDir or false) { + devenv.tmpdir = devenv_tmpdir; + devenv.runtime = devenv_runtime; + }) + (pkgs.lib.optionalAttrs (inputs.devenv.hasIsTesting or false) { + devenv.isTesting = devenv_istesting; + }) + (pkgs.lib.optionalAttrs (container_name != null) { + container.isBuilding = pkgs.lib.mkForce true; + containers.${container_name}.isBuilding = true; + }) + ({ options, ... }: { + config.devenv = pkgs.lib.optionalAttrs (builtins.hasAttr "direnvrcLatestVersion" options.devenv) { + direnvrcLatestVersion = devenv_direnvrc_latest_version; + }; + }) + ] ++ (map importModule (devenv.imports or [ ])) ++ [ + (if builtins.pathExists ./devenv.nix then ./devenv.nix else { }) + (devenv.devenv or { }) + (if builtins.pathExists ./devenv.local.nix then ./devenv.local.nix else { }) + (if builtins.pathExists (devenv_dotfile + "/cli-options.nix") then import (devenv_dotfile + "/cli-options.nix") else { }) + ]; + }; + config = project.config; + + options = pkgs.nixosOptionsDoc { + options = builtins.removeAttrs project.options [ "_module" ]; + warningsAreErrors = false; + # Unpack Nix types, e.g. literalExpression, mDoc. + transformOptions = + let isDocType = v: builtins.elem v [ "literalDocBook" "literalExpression" "literalMD" "mdDoc" ]; + in lib.attrsets.mapAttrs (_: v: + if v ? _type && isDocType v._type then + v.text + else if v ? _type && v._type == "derivation" then + v.name + else + v + ); + }; + + build = options: config: + lib.concatMapAttrs + (name: option: + if builtins.hasAttr "type" option then + if option.type.name == "output" || option.type.name == "outputOf" then { + ${name} = config.${name}; + } else { } + else + let v = build option config.${name}; + in if v != { } then { + ${name} = v; + } else { } + ) + options; + + systems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ]; + in + { + devShell = lib.genAttrs systems (system: config.shell); + packages = lib.genAttrs systems (system: { + optionsJSON = options.optionsJSON; + # deprecated + inherit (config) info procfileScript procfileEnv procfile; + ci = config.ciDerivation; + }); + devenv = config; + build = build project.options project.config; + }; + } diff --git a/.devenv/bash b/.devenv/bash new file mode 120000 index 0000000..3eab571 --- /dev/null +++ b/.devenv/bash @@ -0,0 +1 @@ +/nix/store/94lg0shvsfc845zy8gnflvpqxxiyijbz-bash-interactive-5.2p37 \ No newline at end of file diff --git a/.devenv/devenv.json b/.devenv/devenv.json new file mode 100644 index 0000000..bfa79af --- /dev/null +++ b/.devenv/devenv.json @@ -0,0 +1 @@ +{"inputs":{"nixpkgs":{"url":"github:NixOS/nixpkgs/nixos-unstable"},"nixpkgs-python":{"url":"github:cachix/nixpkgs-python","inputs":{"nixpkgs":{"follows":"nixpkgs"}}}},"allowUnfree":true} \ No newline at end of file diff --git a/.devenv/flake.json b/.devenv/flake.json new file mode 100644 index 0000000..c487dcb --- /dev/null +++ b/.devenv/flake.json @@ -0,0 +1 @@ +{"nixpkgs":{"url":"github:NixOS/nixpkgs/nixos-unstable"},"nixpkgs-python":{"url":"github:cachix/nixpkgs-python","inputs":{"nixpkgs":{"follows":"nixpkgs"}}}} \ No newline at end of file diff --git a/.devenv/gc/shell b/.devenv/gc/shell new file mode 120000 index 0000000..2b5306e --- /dev/null +++ b/.devenv/gc/shell @@ -0,0 +1 @@ +shell-1-link \ No 
newline at end of file diff --git a/.devenv/gc/shell-1-link b/.devenv/gc/shell-1-link new file mode 120000 index 0000000..eacdc2d --- /dev/null +++ b/.devenv/gc/shell-1-link @@ -0,0 +1 @@ +/nix/store/7fimdw1in7f1g0wxw5cr9pg26rs4rp5g-devenv-shell-env \ No newline at end of file diff --git a/.devenv/imports.txt b/.devenv/imports.txt new file mode 100644 index 0000000..e69de29 diff --git a/.devenv/input-paths.txt b/.devenv/input-paths.txt new file mode 100644 index 0000000..6d1c4e8 --- /dev/null +++ b/.devenv/input-paths.txt @@ -0,0 +1,11 @@ +/home/centra/.config/nixpkgs/config.nix +/home/centra/.config/nixpkgs/overlays +/home/centra/.config/nixpkgs/overlays.nix +/home/centra/.nixpkgs/config.nix +/home/centra/dev/pnn/progressive-llm-training/.devenv/flake.json +/home/centra/dev/pnn/progressive-llm-training/.devenv.flake.nix +/home/centra/dev/pnn/progressive-llm-training/.env +/home/centra/dev/pnn/progressive-llm-training/devenv.local.nix +/home/centra/dev/pnn/progressive-llm-training/devenv.lock +/home/centra/dev/pnn/progressive-llm-training/devenv.nix +/home/centra/dev/pnn/progressive-llm-training/devenv.yaml \ No newline at end of file diff --git a/.devenv/load-exports b/.devenv/load-exports new file mode 100755 index 0000000..c0b1498 --- /dev/null +++ b/.devenv/load-exports @@ -0,0 +1,3 @@ +export PATH='/home/centra/dev/pnn/progressive-llm-training/.devenv/state/venv/bin:/nix/store/bdqwd2frn9m7n3hj2436s0vlnv7mawpc-python3-3.11.13-env/bin:/nix/store/9w80x8njl1hcp8vlk1f3x17q4hcd2cqp-evaluate/bin:/nix/store/8df6wqahd2fqzl04kcs3xs32yqqimcxb-install-packages/bin:/nix/store/v5rz1h6ci23icfp6y228r2m0fqrdf408-install-packages-cpu/bin:/nix/store/69142b4sjmb4jffmyjb8nar6qzlgxnpg-prepare-data/bin:/nix/store/bhb6l6yfqknnwc7y5j5xc9k866hajv7b-train/bin:/nix/store/pbqah1qk4b5y14fqinr1h8zvhqy71v81-gcc-wrapper-14.3.0/bin:/nix/store/sa7j7cddyblhcb3ch3ds10w7nw75yjj1-gcc-14.3.0/bin:/nix/store/mdmsnfcvxyk5ynz7nx8nhss1wig0gljx-glibc-2.40-66-bin/bin:/nix/store/psy9v2asypgl9ylg8cnzkixc7fv0snj0-coreutils-9.7/bin:/nix/store/cadx5p7c0i06gf6h84iw9mrhx56imbv0-binutils-wrapper-2.44/bin:/nix/store/z3za8hfc24wb117s50p8b10agjkgm039-binutils-2.44/bin:/nix/store/dx4bdrs7mq3jfviqhszrc7l35ps9kg64-cmake-3.31.7/bin:/nix/store/1492q00cm64n0hs5966s8cqj6j0x5nxg-ninja-1.12.1/bin:/nix/store/h5khrpnjj3fb182sc32fx1z75w0lhksy-pkg-config-wrapper-0.29.2/bin:/nix/store/rzqvhv48m3nh8g3j4k6jmz6yqy8apr95-git-2.49.0/bin:/nix/store/nygfbkv0j6fvwwa82mdwxm4qfiq3p4q2-git-lfs-3.6.1/bin:/nix/store/fir4g1m8dvg46mh8silh3wnmm9mc0jix-htop-3.4.1/bin:/nix/store/9mc2m4sacbk4l7sc4w7m08m1x9bf5dgn-tmux-3.5a/bin:/nix/store/cxy72qdk41k3zjs5fw1nw1whv6wf7hv2-vim-9.1.1401/bin:/nix/store/74k8qwbfa6lm8psm2vjh2vj04fpr6c5g-openssl-3.4.1-bin/bin:/nix/store/m9k83ip1yx29xs94sa5x8j70s2vfgj6i-glib-2.84.2-dev/bin:/nix/store/zs5crhr67zp8cxn7dh4mwq08zw3sb31m-gettext-0.22.5/bin:/nix/store/rklrz4rwi03hxvz0kwh75vz55wb9b1qz-glib-2.84.2-bin/bin:/nix/store/xbpwk3xzanxj12157byj6wjagm2wfb3c-cuda-merged-12.8/bin:/nix/store/v0zrnzl3anb71ma5c2kx71dl8kyh0rf6-cuda_cuobjdump-12.8.90-bin/bin:/nix/store/v4mm21f67qki6ss6mqp3anlmaiw0r1zd-pre-commit-bin/bin:/nix/store/mq2i9br9h890bnahlds9jnff1jf6xjpb-python3.13-black-25.1.0/bin:/nix/store/sd81bvmch7njdpwx3lkjslixcbj5mivz-python3-3.13.4/bin:/nix/store/mdzm1l0rnpwp8ha0mbxll0db4r2p0xj3-python3.13-flake8-7.2.0/bin:/nix/store/xs72vlx7i6snrrrqx2zn529fbbqrwlwq-python3.13-pycodestyle-2.13.0/bin:/nix/store/5a8m3p0svp6myq1cz4ww431fsbh3xrg5-python3.13-pyflakes-3.3.2/bin:/nix/store/p6bch581drrxv3dm7vwxqazpbssjz4nv-python3.13-mypy-1.15.0/bin:/nix/store/1c8sm
86wj45vwkb3ww2b870h9i9wna6r-patchelf-0.15.0/bin:/nix/store/psy9v2asypgl9ylg8cnzkixc7fv0snj0-coreutils-9.7/bin:/nix/store/c14zwgl8hf1wm0izij2i16xvk8ak70cy-findutils-4.10.0/bin:/nix/store/ibx4jfwlhjg4g0s6rrxrpaxa3ka8ns4m-diffutils-3.12/bin:/nix/store/pr318zsl44jdwpk9wk0sdrn19b6in7ah-gnused-4.9/bin:/nix/store/bc6zxzjnkjp4r9nhz5imy3cypvdh6r4n-gnugrep-3.12/bin:/nix/store/nv3y7zb1cwz1h9qy7nwz0s54j8dl1kqj-gawk-5.3.2/bin:/nix/store/lp82dcnrzljyix6yigwzrlpr1smvpmb0-gnutar-1.35/bin:/nix/store/6ag5dhk7sma61p6vl0kazfmpbrq08nqh-gzip-1.14/bin:/nix/store/ykdv4id6893gmkqwdmbimq237c1xqvq7-bzip2-1.0.8-bin/bin:/nix/store/6bwp1y45zlyvpr4ja2sk1yi9v5mrs94x-gnumake-4.4.1/bin:/nix/store/00zrahbb32nzawrmv9sjxn36h7qk9vrs-bash-5.2p37/bin:/nix/store/c9xmgszbf6i4dfq9r953khk9d7fdqigw-patch-2.8/bin:/nix/store/ikfwx7kbwz9zr7fziiac7f57jgbh3bnv-xz-5.8.1-bin/bin:/nix/store/3pdmbqy86wsbjdazxv1n3vrmj60vn0ri-file-5.45/bin:/run/wrappers/bin:/home/centra/.local/share/flatpak/exports/bin:/var/lib/flatpak/exports/bin:/home/centra/.nix-profile/bin:/nix/profile/bin:/home/centra/.local/state/nix/profile/bin:/etc/profiles/per-user/centra/bin:/nix/var/nix/profiles/default/bin:/run/current-system/sw/bin' +export VIRTUAL_ENV=/home/centra/dev/pnn/progressive-llm-training/.devenv/state/venv + diff --git a/.devenv/nix-eval-cache.db b/.devenv/nix-eval-cache.db new file mode 100644 index 0000000..7ee7c11 Binary files /dev/null and b/.devenv/nix-eval-cache.db differ diff --git a/.devenv/nix-eval-cache.db-shm b/.devenv/nix-eval-cache.db-shm new file mode 100644 index 0000000..206cf3b Binary files /dev/null and b/.devenv/nix-eval-cache.db-shm differ diff --git a/.devenv/nix-eval-cache.db-wal b/.devenv/nix-eval-cache.db-wal new file mode 100644 index 0000000..aea3837 Binary files /dev/null and b/.devenv/nix-eval-cache.db-wal differ diff --git a/.devenv/profile b/.devenv/profile new file mode 120000 index 0000000..0a7733c --- /dev/null +++ b/.devenv/profile @@ -0,0 +1 @@ +/nix/store/y2vscmx3lckyzyag6xg8b02pkdsk326d-devenv-profile \ No newline at end of file diff --git a/.devenv/run b/.devenv/run new file mode 120000 index 0000000..5d29d76 --- /dev/null +++ b/.devenv/run @@ -0,0 +1 @@ +/run/user/1000/devenv-adeda32 \ No newline at end of file diff --git a/.devenv/state/git-hooks/config.json b/.devenv/state/git-hooks/config.json new file mode 100644 index 0000000..be68384 --- /dev/null +++ b/.devenv/state/git-hooks/config.json @@ -0,0 +1 @@ +{configPath:.pre-commit-config.yaml} diff --git a/.devenv/tasks.db b/.devenv/tasks.db new file mode 100644 index 0000000..7ee7c11 Binary files /dev/null and b/.devenv/tasks.db differ diff --git a/.devenv/tasks.db-shm b/.devenv/tasks.db-shm new file mode 100644 index 0000000..f5895a2 Binary files /dev/null and b/.devenv/tasks.db-shm differ diff --git a/.devenv/tasks.db-wal b/.devenv/tasks.db-wal new file mode 100644 index 0000000..8411545 Binary files /dev/null and b/.devenv/tasks.db-wal differ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..20b8504 --- /dev/null +++ b/.gitignore @@ -0,0 +1,32 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +venv/ +ENV/ +env/ +.venv/ + +# Nix +result +result-* + +# Project specific +outputs/ +data/ +*.log +wandb/ +.ipynb_checkpoints/ +*.pt +*.pth +*.bin +*.safetensors + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ \ No newline at end of file diff --git a/=2.5.0 b/=2.5.0 new file mode 100644 index 0000000..0d0eafa --- /dev/null +++ b/=2.5.0 @@ -0,0 +1,33 @@ +Collecting flash-attn + Using cached 
flash_attn-2.8.0.post2-cp311-cp311-linux_x86_64.whl +Requirement already satisfied: torch in ./.devenv/state/venv/lib/python3.11/site-packages (from flash-attn) (2.7.1+cu128) +Collecting einops (from flash-attn) + Using cached einops-0.8.1-py3-none-any.whl.metadata (13 kB) +Requirement already satisfied: filelock in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (3.13.1) +Requirement already satisfied: typing-extensions>=4.10.0 in /nix/store/x74hdbjsz4ck98w8lyxv8kkwxs1wm2il-python3.13-typing-extensions-4.13.2/lib/python3.13/site-packages (from torch->flash-attn) (4.13.2) +Requirement already satisfied: sympy>=1.13.3 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (1.13.3) +Requirement already satisfied: networkx in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (3.3) +Requirement already satisfied: jinja2 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (3.1.4) +Requirement already satisfied: fsspec in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (2024.6.1) +Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.61 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.61) +Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.57 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.57) +Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.57 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.57) +Requirement already satisfied: nvidia-cudnn-cu12==9.7.1.26 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (9.7.1.26) +Requirement already satisfied: nvidia-cublas-cu12==12.8.3.14 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.3.14) +Requirement already satisfied: nvidia-cufft-cu12==11.3.3.41 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (11.3.3.41) +Requirement already satisfied: nvidia-curand-cu12==10.3.9.55 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (10.3.9.55) +Requirement already satisfied: nvidia-cusolver-cu12==11.7.2.55 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (11.7.2.55) +Requirement already satisfied: nvidia-cusparse-cu12==12.5.7.53 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.5.7.53) +Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (0.6.3) +Requirement already satisfied: nvidia-nccl-cu12==2.26.2 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (2.26.2) +Requirement already satisfied: nvidia-nvtx-cu12==12.8.55 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.55) +Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.61 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (12.8.61) +Requirement already satisfied: nvidia-cufile-cu12==1.13.0.11 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (1.13.0.11) +Requirement already satisfied: triton==3.3.1 in ./.devenv/state/venv/lib/python3.11/site-packages (from torch->flash-attn) (3.3.1) +Requirement already satisfied: setuptools>=40.8.0 in ./.devenv/state/venv/lib/python3.11/site-packages (from triton==3.3.1->torch->flash-attn) (80.9.0) +Requirement already satisfied: 
mpmath<1.4,>=1.1.0 in ./.devenv/state/venv/lib/python3.11/site-packages (from sympy>=1.13.3->torch->flash-attn) (1.3.0)
+Requirement already satisfied: MarkupSafe>=2.0 in ./.devenv/state/venv/lib/python3.11/site-packages (from jinja2->torch->flash-attn) (2.1.5)
+Using cached einops-0.8.1-py3-none-any.whl (64 kB)
+Installing collected packages: einops, flash-attn
+
+Successfully installed einops-0.8.1 flash-attn-2.8.0.post2
diff --git a/LORA_TARGET_MODULES.md b/LORA_TARGET_MODULES.md
new file mode 100644
index 0000000..39aa5c8
--- /dev/null
+++ b/LORA_TARGET_MODULES.md
@@ -0,0 +1,124 @@
+# LoRA Target Modules Reference
+
+This document provides the correct target module names for different model architectures when using LoRA (Low-Rank Adaptation).
+
+## Model Architecture Detection
+
+Use the inspection script to find correct target modules:
+
+```bash
+# In the nix development environment
+python /home/centra/dev/pnn/inspect_conv1d_model.py [model_name]
+```
+
+## Common Model Architectures
+
+### GPT-2 / DialoGPT Models
+- **Model Type**: GPT2LMHeadModel
+- **Layer Type**: Conv1D (not Linear!)
+- **Base Model**: microsoft/DialoGPT-small, gpt2, gpt2-medium, gpt2-large, gpt2-xl
+
+#### Attention Modules
+- `c_attn` - Combined query, key, value projection (nf=3*hidden_size)
+- `c_proj` - Output projection
+
+#### MLP Modules
+- `mlp.c_fc` - Feed-forward up projection
+- `mlp.c_proj` - Feed-forward down projection
+
+#### Recommended Configurations
+```yaml
+# Basic stage (attention only)
+target_modules: ["c_attn", "c_proj"]
+
+# Advanced stage (attention + MLP)
+target_modules: ["c_attn", "c_proj", "mlp.c_fc", "mlp.c_proj"]
+```
+
+### LLaMA Models
+- **Model Type**: LlamaForCausalLM
+- **Layer Type**: Linear
+- **Base Model**: meta-llama/Llama-2-7b-hf, meta-llama/Llama-3.2-8B
+
+#### Attention Modules
+- `q_proj` - Query projection
+- `k_proj` - Key projection
+- `v_proj` - Value projection
+- `o_proj` - Output projection
+
+#### MLP Modules
+- `gate_proj` - Gate projection
+- `up_proj` - Up projection
+- `down_proj` - Down projection
+
+#### Recommended Configurations
+```yaml
+# Basic stage (attention only)
+target_modules: ["q_proj", "v_proj"]
+
+# Advanced stage (attention + MLP)
+target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+```
+
+### Mistral Models
+- **Model Type**: MistralForCausalLM
+- **Layer Type**: Linear
+- **Base Model**: mistralai/Mistral-7B-v0.1
+
+#### Target Modules (same as LLaMA)
+```yaml
+target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+```
+
+### Qwen Models
+- **Model Type**: QWenLMHeadModel
+- **Layer Type**: Linear
+- **Base Model**: Qwen/Qwen-7B
+
+#### Target Modules
+```yaml
+target_modules: ["c_attn", "c_proj", "w1", "w2"]
+```
+
+## Important Notes
+
+1. **Conv1D vs Linear**: GPT-2 based models use `Conv1D` layers, not `Linear` layers
+2. **Module Patterns**: Use simple patterns like `"c_attn"` rather than full paths like `"transformer.h.0.attn.c_attn"`
+3. **Testing**: Always test your configuration before training by creating a PEFT model
+4. **Architecture Variations**: Different model families use different naming conventions
+
+## Troubleshooting
+
+### Error: "Target module not found"
+- Run the inspection script to find actual module names
+- Check if the model uses Conv1D or Linear layers
+- Verify the module naming pattern for your specific model
+
+### Error: "No trainable parameters"
+- Ensure target modules exist in the model
+- Check that the module names match exactly
+- Verify the model architecture is supported by PEFT
+
+## Testing Your Configuration
+
+```python
+from peft import get_peft_model, LoraConfig, TaskType
+
+# Test configuration
+lora_config = LoraConfig(
+    task_type=TaskType.CAUSAL_LM,
+    r=8,
+    lora_alpha=16,
+    lora_dropout=0.1,
+    target_modules=["c_attn", "c_proj"], # Your target modules
+    bias="none"
+)
+
+# Try to create PEFT model
+try:
+    peft_model = get_peft_model(model, lora_config)
+    peft_model.print_trainable_parameters()
+    print("✓ Configuration works!")
+except Exception as e:
+    print(f"✗ Configuration failed: {e}")
+```
\ No newline at end of file
diff --git a/config/README.md b/config/README.md
new file mode 100644
index 0000000..797a4a0
--- /dev/null
+++ b/config/README.md
@@ -0,0 +1,85 @@
+# Training Configuration Files
+
+This directory contains configuration files for different model sizes and use cases.
+
+## Available Configurations
+
+### Small Models (Testing)
+- `training_config.yaml` - Default configuration for small models (DialoGPT-small)
+  - Memory: ~1GB VRAM
+  - Batch size: 8
+  - No quantization
+
+### Medium Models (8B)
+- `training_config_large.yaml` - Configuration for 8B models (Llama-3.2-8B)
+  - Memory: ~12GB VRAM with 4-bit quantization
+  - Batch size: 1, gradient accumulation: 16-64
+  - 4-bit quantization enabled
+
+### Large Models (13B)
+- `training_config_13b.yaml` - Configuration for 13B models
+  - Memory: ~16GB VRAM with 4-bit quantization
+  - Batch size: 1, gradient accumulation: 32-128
+  - Higher LoRA ranks (32-128)
+
+### Extra Large Models (70B)
+- `training_config_70b.yaml` - Configuration for 70B models
+  - Memory: ~40GB+ VRAM with 4-bit quantization
+  - Batch size: 1, gradient accumulation: 64-256
+  - Maximum LoRA ranks (64-256)
+  - Multi-GPU support with FSDP
+
+## Configuration Parameters
+
+### Model Settings
+- `load_in_4bit`: Enable 4-bit quantization (recommended for large models)
+- `gradient_checkpointing`: Trade compute for memory
+- `use_flash_attention_2`: Faster attention computation if available
+
+### Adapter Settings
+- `r`: LoRA rank (higher = more parameters but better capacity)
+- `lora_alpha`: LoRA scaling factor (typically 2x the rank)
+- `init_lora_weights`: Set to `true` for identity initialization
+
+### Training Settings
+- `per_device_batch_size`: Usually 1 for large models
+- `gradient_accumulation_steps`: Effective batch size multiplier
+- `learning_rate`: Lower for larger models
+- `bf16`: Use bfloat16 for better numerical stability
+
+## Usage
+
+```bash
+# For 8B models
+python scripts/train_progressive.py --config config/training_config_large.yaml
+
+# For 13B models
+python scripts/train_progressive.py --config config/training_config_13b.yaml
+
+# For 70B models (requires multiple GPUs)
+python scripts/train_progressive.py --config config/training_config_70b.yaml
```
+
+## Memory Requirements
+
+| Model Size | VRAM (4-bit) | VRAM (16-bit) | GPUs Recommended |
+|------------|--------------|---------------|------------------|
+| 8B | 12-16GB | 32GB | 1x RTX 4090 |
+| 13B | 16-20GB | 52GB | 1x A100 |
+| 70B | 40-60GB | 140GB | 2x A100 |
+
+## Tips for Large Models
+
+1. **Start with smaller models** to validate your approach
+2. **Use gradient checkpointing** to reduce memory usage
+3. **Monitor GPU memory** during training
+4. **Use lower learning rates** for stability
+5. **Consider multi-GPU setup** for 70B+ models
+6. **Enable flash attention** if available for speed
+
+## Troubleshooting
+
+- **OOM errors**: Reduce batch size or enable gradient checkpointing
+- **Slow training**: Enable flash attention, use bf16
+- **Poor convergence**: Adjust learning rate or warmup steps
+- **Multi-GPU issues**: Check FSDP configuration
\ No newline at end of file
diff --git a/config/training_config.yaml b/config/training_config.yaml
new file mode 100644
index 0000000..ba68d02
--- /dev/null
+++ b/config/training_config.yaml
@@ -0,0 +1,36 @@
+experiment:
+  name: "progressive_reasoning_experiment"
+  base_model: "microsoft/DialoGPT-small" # Lightweight model for testing
+  output_dir: "./outputs"
+  use_wandb: false
+  wandb_project: "matsuo-llm-comp-2025"
+
+model:
+  load_in_4bit: false # Disable quantization for small model
+  bnb_4bit_compute_dtype: "bfloat16"
+  bnb_4bit_use_double_quant: true
+  device_map: "auto"
+
+progressive_stages:
+  - name: "basic_cot"
+    description: "Basic Chain-of-Thought reasoning"
+    dataset_path: "./data/basic_cot/"
+    adapter_config:
+      r: 8
+      lora_alpha: 16
+      lora_dropout: 0.1
+      target_modules: ["c_attn", "c_proj"]
+    training:
+      num_epochs: 2
+      per_device_batch_size: 8 # Increase batch size for small model
+      gradient_accumulation_steps: 2 # Reduce accumulation steps
+      learning_rate: 5e-4 # Higher learning rate for faster training
+      warmup_steps: 50
+      max_length: 1024 # Shorter sequences
+
+evaluation:
+  benchmarks:
+    - "HLE" # Humanity's Last Exam
+    - "Do-Not-Answer"
+  save_results: true
+  results_dir: "./outputs/evaluation_results"
\ No newline at end of file
diff --git a/config/training_config_13b.yaml b/config/training_config_13b.yaml
new file mode 100644
index 0000000..59fd626
--- /dev/null
+++ b/config/training_config_13b.yaml
@@ -0,0 +1,83 @@
+experiment:
+  name: "progressive_reasoning_13b"
+  base_model: "meta-llama/Llama-3.2-13B" # 13B model
+  output_dir: "./outputs"
+  use_wandb: true
+  wandb_project: "matsuo-llm-comp-2025"
+
+model:
+  load_in_4bit: true
+  bnb_4bit_compute_dtype: "bfloat16"
+  bnb_4bit_use_double_quant: true
+  bnb_4bit_quant_type: "nf4"
+  device_map: "auto"
+  gradient_checkpointing: true
+  use_flash_attention_2: true
+
+progressive_stages:
+  - name: "basic_cot"
+    description: "Basic Chain-of-Thought reasoning"
+    dataset_path: "./data/basic_cot/"
+    adapter_config:
+      r: 32 # Higher rank for 13B models
+      lora_alpha: 64
+      lora_dropout: 0.05
+      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"]
+      init_lora_weights: true
+    training:
+      num_epochs: 1
+      per_device_batch_size: 1
+      gradient_accumulation_steps: 32
+      learning_rate: 1e-4
+      warmup_steps: 100
+      max_length: 2048
+      bf16: true
+      max_grad_norm: 0.3
+      weight_decay: 0.001
+
+  - name: "math_reasoning"
+    description: "Mathematical reasoning with think tags"
+    dataset_path: "./data/math_reasoning/"
+    inherit_from: "basic_cot"
+    adapter_config:
+      r: 64
+      lora_alpha: 128
+      lora_dropout: 0.05
+      target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+      init_lora_weights: true
+    training:
+      num_epochs: 2
+      per_device_batch_size: 1
+      gradient_accumulation_steps: 64
+      learning_rate: 8e-5
+      warmup_steps: 200
+      max_length: 4096
+      bf16: true
+      max_grad_norm: 0.3
+
+  - name: "complex_reasoning"
+
description: "Complex multi-step reasoning" + dataset_path: "./data/complex_reasoning/" + inherit_from: "math_reasoning" + adapter_config: + r: 128 # Maximum rank for 13B models + lora_alpha: 256 + lora_dropout: 0.05 + target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + init_lora_weights: true + training: + num_epochs: 1 + per_device_batch_size: 1 + gradient_accumulation_steps: 128 + learning_rate: 5e-5 + warmup_steps: 300 + max_length: 8192 + bf16: true + max_grad_norm: 0.3 + +evaluation: + benchmarks: + - "HLE" + - "Do-Not-Answer" + save_results: true + results_dir: "./outputs/evaluation_results" \ No newline at end of file diff --git a/config/training_config_70b.yaml b/config/training_config_70b.yaml new file mode 100644 index 0000000..ed44f42 --- /dev/null +++ b/config/training_config_70b.yaml @@ -0,0 +1,101 @@ +experiment: + name: "progressive_reasoning_70b" + base_model: "meta-llama/Llama-3.2-70B" # 70B model - requires significant resources + output_dir: "./outputs" + use_wandb: true + wandb_project: "matsuo-llm-comp-2025" + +model: + load_in_4bit: true + bnb_4bit_compute_dtype: "bfloat16" + bnb_4bit_use_double_quant: true + bnb_4bit_quant_type: "nf4" + device_map: "auto" + gradient_checkpointing: true + use_flash_attention_2: true + +progressive_stages: + - name: "basic_cot" + description: "Basic Chain-of-Thought reasoning" + dataset_path: "./data/basic_cot/" + adapter_config: + r: 64 # Even higher rank for 70B models + lora_alpha: 128 + lora_dropout: 0.05 + target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"] + init_lora_weights: true + training: + num_epochs: 1 + per_device_batch_size: 1 + gradient_accumulation_steps: 64 + learning_rate: 5e-5 # Lower learning rate for stability + warmup_steps: 200 + max_length: 2048 + bf16: true + max_grad_norm: 0.3 + weight_decay: 0.001 + dataloader_num_workers: 2 + + - name: "math_reasoning" + description: "Mathematical reasoning with think tags" + dataset_path: "./data/math_reasoning/" + inherit_from: "basic_cot" + adapter_config: + r: 128 + lora_alpha: 256 + lora_dropout: 0.05 + target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + init_lora_weights: true + training: + num_epochs: 1 + per_device_batch_size: 1 + gradient_accumulation_steps: 128 + learning_rate: 3e-5 + warmup_steps: 300 + max_length: 4096 + bf16: true + max_grad_norm: 0.3 + dataloader_num_workers: 2 + + - name: "complex_reasoning" + description: "Complex multi-step reasoning" + dataset_path: "./data/complex_reasoning/" + inherit_from: "math_reasoning" + adapter_config: + r: 256 # Maximum rank for 70B models + lora_alpha: 512 + lora_dropout: 0.05 + target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + init_lora_weights: true + training: + num_epochs: 1 + per_device_batch_size: 1 + gradient_accumulation_steps: 256 + learning_rate: 2e-5 + warmup_steps: 500 + max_length: 8192 + bf16: true + max_grad_norm: 0.3 + dataloader_num_workers: 2 + +evaluation: + benchmarks: + - "HLE" + - "Do-Not-Answer" + save_results: true + results_dir: "./outputs/evaluation_results" + +# Additional settings for 70B models +optimization: + gradient_checkpointing: true + gradient_checkpointing_kwargs: + use_reentrant: false + ddp_find_unused_parameters: false + # Multi-GPU settings + fsdp: "full_shard auto_wrap" + fsdp_transformer_layer_cls_to_wrap: "LlamaDecoderLayer" + fsdp_min_num_params: 1000000 + fsdp_config: + min_num_params: 1000000 + sharding_strategy: "FULL_SHARD" + 
cpu_offload: false \ No newline at end of file diff --git a/config/training_config_gemma2_small.yaml b/config/training_config_gemma2_small.yaml new file mode 100644 index 0000000..fa035d6 --- /dev/null +++ b/config/training_config_gemma2_small.yaml @@ -0,0 +1,91 @@ +experiment: + name: "progressive_reasoning_gemma2_small" + base_model: "google/gemma-2-2b-it" # Instruction-tuned version + output_dir: "./outputs" + use_wandb: true + wandb_project: "matsuo-llm-comp-2025" + +model: + load_in_4bit: false # 2B model is manageable without quantization + bnb_4bit_compute_dtype: "bfloat16" + bnb_4bit_use_double_quant: true + device_map: "auto" + gradient_checkpointing: false + use_flash_attention_2: false + use_eager_attention: true # Required for Gemma 3 models + +progressive_stages: + - name: "basic_cot" + description: "Basic Chain-of-Thought reasoning" + dataset_path: "./data/basic_cot/" + adapter_config: + r: 8 # Start with smaller rank for small model + lora_alpha: 16 + lora_dropout: 0.1 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] + init_lora_weights: true + training: + num_epochs: 3 + per_device_batch_size: 8 # Larger batch size for small model + gradient_accumulation_steps: 2 + learning_rate: 5e-4 # Higher learning rate for small model + warmup_steps: 50 + max_length: 1024 + bf16: true + max_grad_norm: 1.0 + weight_decay: 0.001 + save_steps: 50 + logging_steps: 10 + + - name: "math_reasoning" + description: "Mathematical reasoning with think tags" + dataset_path: "./data/math_reasoning/" + inherit_from: "basic_cot" + adapter_config: + r: 16 + lora_alpha: 32 + lora_dropout: 0.1 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + init_lora_weights: true + training: + num_epochs: 3 + per_device_batch_size: 4 + gradient_accumulation_steps: 4 + learning_rate: 3e-4 + warmup_steps: 100 + max_length: 2048 + bf16: true + max_grad_norm: 1.0 + + - name: "complex_reasoning" + description: "Complex multi-step reasoning with Mixture-of-Thoughts" + dataset_path: "open-r1/Mixture-of-Thoughts" # HuggingFace dataset + inherit_from: "math_reasoning" + adapter_config: + r: 32 + lora_alpha: 64 + lora_dropout: 0.1 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + init_lora_weights: true + training: + num_epochs: 1 # Large dataset, fewer epochs + per_device_batch_size: 2 + gradient_accumulation_steps: 8 + learning_rate: 2e-4 + warmup_steps: 200 + max_length: 4096 + bf16: true + max_grad_norm: 1.0 + save_steps: 500 + logging_steps: 50 + dataset_config: + streaming: true + max_samples: 30000 + split: "train" + +evaluation: + benchmarks: + - "HLE" + - "Do-Not-Answer" + save_results: true + results_dir: "./outputs/evaluation_results" \ No newline at end of file diff --git a/config/training_config_gemma3_1b.yaml b/config/training_config_gemma3_1b.yaml new file mode 100644 index 0000000..2433612 --- /dev/null +++ b/config/training_config_gemma3_1b.yaml @@ -0,0 +1,102 @@ +experiment: + name: "progressive_reasoning_gemma3_1b" + base_model: "google/gemma-3-1b-pt" # Using Gemma 2 2B (1B might not be available) + output_dir: "./outputs" + use_wandb: true + wandb_project: "matsuo-llm-comp-2025" + +model: + load_in_4bit: false + bnb_4bit_compute_dtype: "bfloat16" + bnb_4bit_use_double_quant: true + device_map: "auto" + gradient_checkpointing: false # Not needed for small models + use_flash_attention_2: false + use_eager_attention: true + +progressive_stages: + - name: "basic_cot" + description: "Basic Chain-of-Thought 
reasoning" + dataset_path: "./data/basic_cot/" + adapter_config: + r: 8 + lora_alpha: 16 + lora_dropout: 0.1 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] # Gemma attention modules + init_lora_weights: true + training: + num_epochs: 2 + per_device_batch_size: 8 + gradient_accumulation_steps: 2 + learning_rate: 5e-4 + warmup_steps: 50 + max_length: 1024 + fp16: false + bf16: true + max_grad_norm: 1.0 + weight_decay: 0.001 + save_steps: 100 + logging_steps: 10 + + - name: "math_reasoning" + description: "Mathematical reasoning with OpenR1-Math-220k dataset" + dataset_path: "open-r1/OpenR1-Math-220k" # HuggingFace dataset + inherit_from: "basic_cot" + adapter_config: + r: 16 + lora_alpha: 32 + lora_dropout: 0.1 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + init_lora_weights: true + training: + num_epochs: 1 # Large dataset, fewer epochs + per_device_batch_size: 4 + gradient_accumulation_steps: 4 + learning_rate: 3e-4 + warmup_steps: 100 + max_length: 2048 + bf16: true + max_grad_norm: 1.0 + weight_decay: 0.001 + save_steps: 1000 + logging_steps: 100 + dataset_config: + # OpenR1-Math-220k specific settings + streaming: true # Use streaming for large dataset + max_samples: 200000 # Limit samples for faster training + split: "train" + + - name: "complex_reasoning" + description: "Complex multi-step reasoning with Mixture-of-Thoughts" + dataset_path: "open-r1/Mixture-of-Thoughts" # HuggingFace dataset + inherit_from: "math_reasoning" + adapter_config: + r: 32 + lora_alpha: 64 + lora_dropout: 0.1 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + init_lora_weights: true + training: + num_epochs: 1 # Large dataset, fewer epochs + per_device_batch_size: 2 + gradient_accumulation_steps: 8 + learning_rate: 2e-4 + warmup_steps: 200 + max_length: 4096 + bf16: true + max_grad_norm: 1.0 + weight_decay: 0.001 + save_steps: 500 + logging_steps: 50 + dataset_config: + # Mixture-of-Thoughts specific settings + streaming: true # Use streaming for large dataset + max_samples: 30000 # Limit samples for faster training + split: "train" + +evaluation: + benchmarks: + - "HLE" + - "Do-Not-Answer" + save_results: true + results_dir: "./outputs/evaluation_results" diff --git a/config/training_config_gemma3_1b_cpu_offload.yaml b/config/training_config_gemma3_1b_cpu_offload.yaml new file mode 100644 index 0000000..2b4158c --- /dev/null +++ b/config/training_config_gemma3_1b_cpu_offload.yaml @@ -0,0 +1,133 @@ +experiment: + name: "progressive_reasoning_gemma3_1b_cpu_offload" + base_model: "google/gemma-3-1b-pt" # Using Gemma 3 1B + output_dir: "./outputs" + use_wandb: true + wandb_project: "matsuo-llm-comp-2025" + +model: + load_in_4bit: true # Enable 4-bit quantization for QLoRA + bnb_4bit_compute_dtype: "bfloat16" + bnb_4bit_use_double_quant: true + bnb_4bit_quant_type: "nf4" + device_map: "auto" # Let accelerate handle device placement + max_memory: + 0: "5GB" # Limit GPU memory to 3GB (leave room for CUDA kernels) + "cpu": "32GB" # Allow up to 32GB CPU RAM + offload_folder: "./offload" # Directory for disk offloading if needed + gradient_checkpointing: true # Trade compute for memory + use_flash_attention_2: false + use_eager_attention: true + +progressive_stages: + - name: "basic_cot" + description: "Basic Chain-of-Thought reasoning" + dataset_path: "./data/basic_cot/" + adapter_config: + r: 8 # Lower rank for memory efficiency + lora_alpha: 16 + lora_dropout: 0.1 + target_modules: ["q_proj", "k_proj", 
"v_proj", "o_proj"] + init_lora_weights: true + training: + num_epochs: 2 + per_device_batch_size: 2 # Smaller batch size + gradient_accumulation_steps: 8 # Compensate with gradient accumulation + learning_rate: 5e-4 + warmup_steps: 50 + max_length: 512 # Shorter sequences for memory + bf16: true + max_grad_norm: 1.0 + weight_decay: 0.001 + save_steps: 100 + logging_steps: 10 + dataloader_num_workers: 0 # Disable multiprocessing to save memory + optim: "paged_adamw_8bit" # Use 8-bit optimizer + + - name: "math_reasoning" + description: "Mathematical reasoning with OpenR1-Math-220k dataset" + dataset_path: "open-r1/OpenR1-Math-220k" + inherit_from: "basic_cot" + adapter_config: + r: 16 + lora_alpha: 32 + lora_dropout: 0.1 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + init_lora_weights: true + training: + num_epochs: 1 + per_device_batch_size: 1 # Minimal batch size + gradient_accumulation_steps: 16 + learning_rate: 3e-4 + warmup_steps: 100 + max_length: 1024 + bf16: true + max_grad_norm: 1.0 + weight_decay: 0.001 + save_steps: 1000 + logging_steps: 100 + optim: "paged_adamw_8bit" + dataset_config: + streaming: true + max_samples: 200000 # Reduced for testing + split: "train" + + - name: "complex_reasoning" + description: "Complex multi-step reasoning with Mixture-of-Thoughts" + dataset_path: "open-r1/Mixture-of-Thoughts" # HuggingFace dataset + inherit_from: "math_reasoning" + adapter_config: + r: 32 + lora_alpha: 64 + lora_dropout: 0.1 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + init_lora_weights: true + training: + num_epochs: 1 + per_device_batch_size: 1 + gradient_accumulation_steps: 32 + learning_rate: 2e-4 + warmup_steps: 200 + max_length: 2048 + bf16: true + max_grad_norm: 1.0 + weight_decay: 0.001 + optim: "paged_adamw_8bit" + save_steps: 500 + logging_steps: 50 + dataset_config: + streaming: true + max_samples: 300000 # Limited for CPU offload config + split: "train" + +evaluation: + benchmarks: + - "HLE" + - "Do-Not-Answer" + save_results: true + results_dir: "./outputs/evaluation_results" + +# DeepSpeed configuration for advanced CPU offloading (optional) +# Uncomment to use DeepSpeed ZeRO-2 with CPU offload +# deepspeed: +# zero_optimization: +# stage: 2 +# offload_optimizer: +# device: "cpu" +# pin_memory: true +# offload_param: +# device: "cpu" +# pin_memory: true +# overlap_comm: true +# contiguous_gradients: true +# sub_group_size: 1e9 +# reduce_bucket_size: 1e6 + +# FSDP configuration for distributed training (optional) +# Uncomment to use FSDP with CPU offload +# fsdp: +# sharding_strategy: "FULL_SHARD" +# cpu_offload: true +# auto_wrap_policy: "TRANSFORMER_BASED_WRAP" +# transformer_layer_cls_to_wrap: "GemmaDecoderLayer" +# min_num_params: 1e6 diff --git a/config/training_config_large.yaml b/config/training_config_large.yaml new file mode 100644 index 0000000..22cbd3b --- /dev/null +++ b/config/training_config_large.yaml @@ -0,0 +1,98 @@ +experiment: + name: "progressive_reasoning_large_model" + base_model: "meta-llama/Llama-3.2-8B" # Or other whitelisted models + output_dir: "./outputs" + use_wandb: true + wandb_project: "matsuo-llm-comp-2025" + +model: + load_in_4bit: true # Enable 4-bit quantization for memory efficiency + bnb_4bit_compute_dtype: "bfloat16" + bnb_4bit_use_double_quant: true + bnb_4bit_quant_type: "nf4" + device_map: "auto" + # Additional memory optimizations + gradient_checkpointing: true + use_flash_attention_2: true # If available + 
+progressive_stages: + - name: "basic_cot" + description: "Basic Chain-of-Thought reasoning" + dataset_path: "./data/basic_cot/" + adapter_config: + r: 16 # Larger rank for bigger models + lora_alpha: 32 + lora_dropout: 0.05 + target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"] + init_lora_weights: true # Identity initialization + training: + num_epochs: 1 + per_device_batch_size: 1 # Small batch size for large models + gradient_accumulation_steps: 16 # Effective batch size = 16 + learning_rate: 2e-4 + warmup_steps: 100 + max_length: 2048 + fp16: false + bf16: true + max_grad_norm: 0.3 + weight_decay: 0.001 + save_steps: 50 + logging_steps: 10 + + - name: "math_reasoning" + description: "Mathematical reasoning with think tags" + dataset_path: "./data/math_reasoning/" + inherit_from: "basic_cot" + adapter_config: + r: 32 # Increase rank for more complex tasks + lora_alpha: 64 + lora_dropout: 0.05 + target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + init_lora_weights: true + training: + num_epochs: 2 + per_device_batch_size: 1 + gradient_accumulation_steps: 32 # Effective batch size = 32 + learning_rate: 1e-4 + warmup_steps: 200 + max_length: 4096 + bf16: true + max_grad_norm: 0.3 + weight_decay: 0.001 + + - name: "complex_reasoning" + description: "Complex multi-step reasoning" + dataset_path: "./data/complex_reasoning/" + inherit_from: "math_reasoning" + adapter_config: + r: 64 # Maximum rank for most complex tasks + lora_alpha: 128 + lora_dropout: 0.05 + target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + init_lora_weights: true + training: + num_epochs: 2 + per_device_batch_size: 1 + gradient_accumulation_steps: 64 # Effective batch size = 64 + learning_rate: 5e-5 + warmup_steps: 300 + max_length: 8192 + bf16: true + max_grad_norm: 0.3 + weight_decay: 0.001 + +evaluation: + benchmarks: + - "HLE" + - "Do-Not-Answer" + save_results: true + results_dir: "./outputs/evaluation_results" + +# Memory optimization settings +optimization: + gradient_checkpointing: true + gradient_checkpointing_kwargs: + use_reentrant: false + ddp_find_unused_parameters: false + fsdp: "full_shard auto_wrap" # For multi-GPU setups + fsdp_transformer_layer_cls_to_wrap: "LlamaDecoderLayer" \ No newline at end of file diff --git a/config/training_config_llama_auth.yaml b/config/training_config_llama_auth.yaml new file mode 100644 index 0000000..090df6f --- /dev/null +++ b/config/training_config_llama_auth.yaml @@ -0,0 +1,85 @@ +experiment: + name: "progressive_reasoning_llama_auth" + base_model: "meta-llama/Llama-3.2-8B" + output_dir: "./outputs" + use_wandb: true + wandb_project: "matsuo-llm-comp-2025" + +model: + load_in_4bit: true + bnb_4bit_compute_dtype: "bfloat16" + bnb_4bit_use_double_quant: true + bnb_4bit_quant_type: "nf4" + device_map: "auto" + gradient_checkpointing: true + use_flash_attention_2: true + # Add your HuggingFace token here, or set HF_TOKEN environment variable + # hf_token: "your_token_here" + +progressive_stages: + - name: "basic_cot" + description: "Basic Chain-of-Thought reasoning" + dataset_path: "./data/basic_cot/" + adapter_config: + r: 16 + lora_alpha: 32 + lora_dropout: 0.05 + target_modules: ["q_proj", "v_proj", "k_proj", "o_proj"] + init_lora_weights: true + training: + num_epochs: 1 + per_device_batch_size: 1 + gradient_accumulation_steps: 16 + learning_rate: 2e-4 + warmup_steps: 100 + max_length: 2048 + bf16: true + max_grad_norm: 0.3 + weight_decay: 0.001 + + - name: "math_reasoning" + 
description: "Mathematical reasoning with think tags" + dataset_path: "./data/math_reasoning/" + inherit_from: "basic_cot" + adapter_config: + r: 32 + lora_alpha: 64 + lora_dropout: 0.05 + target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + init_lora_weights: true + training: + num_epochs: 2 + per_device_batch_size: 1 + gradient_accumulation_steps: 32 + learning_rate: 1e-4 + warmup_steps: 200 + max_length: 4096 + bf16: true + max_grad_norm: 0.3 + + - name: "complex_reasoning" + description: "Complex multi-step reasoning" + dataset_path: "./data/complex_reasoning/" + inherit_from: "math_reasoning" + adapter_config: + r: 64 + lora_alpha: 128 + lora_dropout: 0.05 + target_modules: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + init_lora_weights: true + training: + num_epochs: 2 + per_device_batch_size: 1 + gradient_accumulation_steps: 64 + learning_rate: 5e-5 + warmup_steps: 300 + max_length: 8192 + bf16: true + max_grad_norm: 0.3 + +evaluation: + benchmarks: + - "HLE" + - "Do-Not-Answer" + save_results: true + results_dir: "./outputs/evaluation_results" \ No newline at end of file diff --git a/config/training_config_public.yaml b/config/training_config_public.yaml new file mode 100644 index 0000000..a1abea4 --- /dev/null +++ b/config/training_config_public.yaml @@ -0,0 +1,82 @@ +experiment: + name: "progressive_reasoning_public_model" + base_model: "microsoft/DialoGPT-medium" # Public model, no authentication needed + output_dir: "./outputs" + use_wandb: false + wandb_project: "matsuo-llm-comp-2025" + +model: + load_in_4bit: false # DialoGPT is smaller, quantization not needed + bnb_4bit_compute_dtype: "bfloat16" + bnb_4bit_use_double_quant: true + device_map: "auto" + gradient_checkpointing: false + +progressive_stages: + - name: "basic_cot" + description: "Basic Chain-of-Thought reasoning" + dataset_path: "./data/basic_cot/" + adapter_config: + r: 16 + lora_alpha: 32 + lora_dropout: 0.1 + target_modules: ["c_attn", "c_proj"] # GPT-2 style attention modules + init_lora_weights: true + training: + num_epochs: 2 + per_device_batch_size: 4 + gradient_accumulation_steps: 4 + learning_rate: 2e-4 + warmup_steps: 100 + max_length: 1024 + fp16: false + bf16: false # Use fp32 for smaller models + max_grad_norm: 1.0 + weight_decay: 0.001 + + - name: "math_reasoning" + description: "Mathematical reasoning with think tags" + dataset_path: "./data/math_reasoning/" + inherit_from: "basic_cot" + adapter_config: + r: 32 + lora_alpha: 64 + lora_dropout: 0.1 + target_modules: ["c_attn", "c_proj"] + init_lora_weights: true + training: + num_epochs: 3 + per_device_batch_size: 2 + gradient_accumulation_steps: 8 + learning_rate: 1e-4 + warmup_steps: 200 + max_length: 2048 + bf16: false + max_grad_norm: 1.0 + + - name: "complex_reasoning" + description: "Complex multi-step reasoning" + dataset_path: "./data/complex_reasoning/" + inherit_from: "math_reasoning" + adapter_config: + r: 64 + lora_alpha: 128 + lora_dropout: 0.1 + target_modules: ["c_attn", "c_proj"] + init_lora_weights: true + training: + num_epochs: 2 + per_device_batch_size: 1 + gradient_accumulation_steps: 16 + learning_rate: 5e-5 + warmup_steps: 300 + max_length: 4096 + bf16: false + max_grad_norm: 1.0 + +evaluation: + benchmarks: + - "HLE" + - "Do-Not-Answer" + save_results: true + results_dir: "./outputs/evaluation_results" \ No newline at end of file diff --git a/devenv.lock b/devenv.lock new file mode 100644 index 0000000..d06c441 --- /dev/null +++ b/devenv.lock @@ 
-0,0 +1,139 @@ +{ + "nodes": { + "devenv": { + "locked": { + "dir": "src/modules", + "lastModified": 1751909516, + "owner": "cachix", + "repo": "devenv", + "rev": "36e4cf7d6cb89862e69efce4e5c147ac2e4d38f9", + "type": "github" + }, + "original": { + "dir": "src/modules", + "owner": "cachix", + "repo": "devenv", + "type": "github" + } + }, + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1747046372, + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "flake": false, + "locked": { + "lastModified": 1747046372, + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "git-hooks": { + "inputs": { + "flake-compat": "flake-compat", + "gitignore": "gitignore", + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1750779888, + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "16ec914f6fb6f599ce988427d9d94efddf25fe6d", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, + "gitignore": { + "inputs": { + "nixpkgs": [ + "git-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1751792365, + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "1fd8bada0b6117e6c7eb54aad5813023eed37ccb", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs-python": { + "inputs": { + "flake-compat": "flake-compat_2", + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1749760516, + "owner": "cachix", + "repo": "nixpkgs-python", + "rev": "908dbb466af5955ea479ac95953333fd64387216", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "nixpkgs-python", + "type": "github" + } + }, + "root": { + "inputs": { + "devenv": "devenv", + "git-hooks": "git-hooks", + "nixpkgs": "nixpkgs", + "nixpkgs-python": "nixpkgs-python", + "pre-commit-hooks": [ + "git-hooks" + ] + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake-minimal.nix b/flake-minimal.nix new file mode 100644 index 0000000..d5fde54 --- /dev/null +++ b/flake-minimal.nix @@ -0,0 +1,95 @@ +{ + description = "Progressive LLM Training for 松尾研LLMコンペ2025 (Minimal)"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, flake-utils }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { + inherit system; + config = { + allowUnfree = true; + cudaSupport = true; + }; + }; + + # Python 3.11 for better compatibility + python = pkgs.python311; + + # Minimal Python packages + pythonWithPackages = python.withPackages (ps: with ps; [ + # Core essentials only + torch + transformers + numpy + + # Essential dependencies + pyyaml + + # Build tools + pip + setuptools + wheel + ]); + + in + { + devShells.default = pkgs.mkShell { + buildInputs = with pkgs; [ + # Python with packages + 
pythonWithPackages + + # Build tools + gcc + cmake + ninja + pkg-config + + # Git + git + git-lfs + + # Libraries needed for Python packages + openssl + zlib + glib + stdenv.cc.cc.lib + + # CUDA support + cudaPackages.cudatoolkit + cudaPackages.cudnn + ]; + + shellHook = '' + echo "🚀 Progressive LLM Training Environment (Minimal)" + echo "Python version: $(python --version)" + echo "PyTorch version: $(python -c 'import torch; print(torch.__version__)')" + echo "CUDA available: $(python -c 'import torch; print(torch.cuda.is_available())')" + + # Set up CUDA environment + export CUDA_HOME=${pkgs.cudaPackages.cudatoolkit} + export CUDA_PATH=${pkgs.cudaPackages.cudatoolkit} + export LD_LIBRARY_PATH=${pkgs.cudaPackages.cudatoolkit}/lib:${pkgs.cudaPackages.cudnn}/lib:${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH + + # Set Python path + export PYTHONPATH=$PWD/src:$PYTHONPATH + + echo "" + echo "Note: This is a minimal configuration. Install additional packages with pip as needed:" + echo " pip install accelerate peft trl datasets bitsandbytes wandb jsonlines scikit-learn sentencepiece protobuf" + echo " pip install flash-attn --no-build-isolation" + ''; + + # Environment variables + CUDA_HOME = "${pkgs.cudaPackages.cudatoolkit}"; + CUDA_PATH = "${pkgs.cudaPackages.cudatoolkit}"; + NIX_SHELL_PRESERVE_PROMPT = 1; + LOCALE_ARCHIVE = "${pkgs.glibcLocales}/lib/locale/locale-archive"; + LC_ALL = "en_US.UTF-8"; + }; + }); +} \ No newline at end of file diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..bd80c39 --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1751792365, + "narHash": "sha256-J1kI6oAj25IG4EdVlg2hQz8NZTBNYvIS0l4wpr9KcUo=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "1fd8bada0b6117e6c7eb54aad5813023eed37ccb", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..d0d9d05 --- /dev/null +++ b/flake.nix @@ -0,0 +1,195 @@ +{ + description = "Progressive LLM Training for 松尾研LLMコンペ2025"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, flake-utils }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { + inherit system; + config = { + allowUnfree = true; + cudaSupport = true; + }; + overlays = [ + (final: prev: { + python311 = prev.python311.override { + packageOverrides = python-self: python-super: { + # Disable tests for problematic packages + pytest-doctestplus = 
python-super.pytest-doctestplus.overrideAttrs (oldAttrs: { + doCheck = false; + doInstallCheck = false; + pytestCheckPhase = "echo 'Skipping tests'"; + }); + # Also disable tests for jupyter-related packages if they cause issues + jupyter = python-super.jupyter.overrideAttrs (oldAttrs: { + doCheck = false; + doInstallCheck = false; + }); + notebook = python-super.notebook.overrideAttrs (oldAttrs: { + doCheck = false; + doInstallCheck = false; + }); + # Disable tests for psycopg and psycopg2 + psycopg = python-super.psycopg.overrideAttrs (oldAttrs: { + doCheck = false; + doInstallCheck = false; + pytestCheckPhase = "echo 'Skipping tests'"; + pythonImportsCheck = []; # Disable import checks + }); + psycopg2 = python-super.psycopg2.overrideAttrs (oldAttrs: { + doCheck = false; + doInstallCheck = false; + pytestCheckPhase = "echo 'Skipping tests'"; + pythonImportsCheck = []; # Disable import checks + }); + # Disable tests for sqlframe + sqlframe = python-super.sqlframe.overrideAttrs (oldAttrs: { + doCheck = false; + doInstallCheck = false; + pytestCheckPhase = "echo 'Skipping tests'"; + pythonImportsCheck = []; # Disable import checks + }); + # Disable tests for accelerate + accelerate = python-super.accelerate.overrideAttrs (oldAttrs: { + doCheck = false; + doInstallCheck = false; + pytestCheckPhase = "echo 'Skipping tests'"; + pythonImportsCheck = []; # Disable import checks + }); + }; + }; + }) + ]; + }; + + # Python 3.11 for better compatibility + python = pkgs.python311; + + # Python packages + pythonWithPackages = python.withPackages (ps: with ps; [ + # Core ML packages + torch + torchvision + torchaudio + transformers + accelerate + datasets + tokenizers + scikit-learn + + # Required dependencies from requirements.txt + pyyaml + jsonlines + sentencepiece + protobuf + + # Additional useful packages + numpy + scipy + matplotlib + jupyter + notebook + ipython + pandas + rich # For TUI + + # Development tools + black + flake8 + pytest + mypy + + # Build tools + pip + setuptools + wheel + + # LLM specific packages + peft + trl + bitsandbytes + wandb + ]); + + in + { + devShells.default = pkgs.mkShell { + buildInputs = with pkgs; [ + # Python with packages + pythonWithPackages + + # Build tools + gcc + cmake + ninja + pkg-config + + # Git + git + git-lfs + + # Development tools + htop + tmux + vim + + # Libraries needed for Python packages + openssl + zlib + glib + stdenv.cc.cc.lib + + # CUDA support + cudaPackages.cudatoolkit + cudaPackages.cudnn + ]; + + shellHook = '' + echo "🚀 Progressive LLM Training Environment" + echo "Python version: $(python --version)" + echo "PyTorch version: $(python -c 'import torch; print(torch.__version__)')" + echo "CUDA available: $(python -c 'import torch; print(torch.cuda.is_available())')" + + # Set up CUDA environment + export CUDA_HOME=${pkgs.cudaPackages.cudatoolkit} + export CUDA_PATH=${pkgs.cudaPackages.cudatoolkit} + export LD_LIBRARY_PATH=${pkgs.cudaPackages.cudatoolkit}/lib:${pkgs.cudaPackages.cudnn}/lib:${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH + + # Set Python path + export PYTHONPATH=$PWD/src:$PYTHONPATH + + echo "" + echo "Available commands:" + echo " python scripts/train_progressive.py # Start training" + echo " python scripts/evaluate.py # Evaluate model" + echo " jupyter notebook # Start Jupyter" + echo "" + + # Create data directory if not exists + mkdir -p data + + # Prepare sample data if not exists + if [ ! -f "data/basic_cot/train.jsonl" ]; then + echo "Preparing sample datasets..." 
+ python -c "from src.data_utils import prepare_sample_datasets; prepare_sample_datasets()" || echo "Sample data preparation skipped" + fi + + # Note about flash-attn + echo "Note: flash-attn is not included in nixpkgs. If needed, install manually with:" + echo " pip install flash-attn --no-build-isolation" + ''; + + # Environment variables + CUDA_HOME = "${pkgs.cudaPackages.cudatoolkit}"; + CUDA_PATH = "${pkgs.cudaPackages.cudatoolkit}"; + NIX_SHELL_PRESERVE_PROMPT = 1; + LOCALE_ARCHIVE = "${pkgs.glibcLocales}/lib/locale/locale-archive"; + LC_ALL = "en_US.UTF-8"; + }; + }); +} \ No newline at end of file diff --git a/requirements-cpu.txt b/requirements-cpu.txt new file mode 100644 index 0000000..6e3167e --- /dev/null +++ b/requirements-cpu.txt @@ -0,0 +1,15 @@ +# CPU version of PyTorch +torch>=2.0.0 --index-url https://download.pytorch.org/whl/cpu +transformers>=4.40.0 +accelerate>=0.27.0 +peft>=0.11.0 +trl>=0.9.0 +datasets>=2.18.0 +bitsandbytes>=0.43.0 +wandb>=0.16.0 +pyyaml>=6.0 +jsonlines>=4.0.0 +scikit-learn>=1.3.0 +# flash-attn is not needed for CPU version +sentencepiece>=0.2.0 +protobuf>=4.25.0 \ No newline at end of file diff --git a/requirements-torch.txt b/requirements-torch.txt new file mode 100644 index 0000000..7a1beba --- /dev/null +++ b/requirements-torch.txt @@ -0,0 +1,3 @@ +--index-url https://download.pytorch.org/whl/cu128 +torch>=2.0.0 +torchaudio>=2.0.0 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..534ab7a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +transformers>=4.40.0 +accelerate>=0.27.0 +peft>=0.11.0 +trl>=0.9.0 +datasets>=2.18.0 +bitsandbytes>=0.43.0 +wandb>=0.16.0 +pyyaml>=6.0 +jsonlines>=4.0.0 +scikit-learn>=1.3.0 +# flash-attn>=2.5.0 # Install separately with --no-build-isolation +sentencepiece>=0.2.0 +protobuf>=4.25.0 diff --git a/scripts/analyze_adapter_size.py b/scripts/analyze_adapter_size.py new file mode 100755 index 0000000..b05b0d3 --- /dev/null +++ b/scripts/analyze_adapter_size.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 +""" +Analyze the size and structure of LoRA adapters +""" + +import sys +from pathlib import Path +import torch +import yaml +from peft import PeftModel, LoraConfig + +# Add src to path +sys.path.append(str(Path(__file__).parent.parent)) + +from src.progressive_model import ProgressiveReasoningModel + + +def analyze_adapter_sizes(): + # Load configuration + with open("config/training_config.yaml") as f: + config = yaml.safe_load(f) + + print("=" * 60) + print("LoRA Adapter Size Analysis") + print("=" * 60) + + # Get adapter configuration from config + basic_cot_config = config["progressive_stages"][0] + adapter_config = basic_cot_config["adapter_config"] + + print(f"\nConfiguration for 'basic_cot' adapter:") + print(f" - r (rank): {adapter_config['r']}") + print(f" - lora_alpha: {adapter_config['lora_alpha']}") + print(f" - lora_dropout: {adapter_config['lora_dropout']}") + print(f" - target_modules: {adapter_config['target_modules']}") + + # Load the base model to get dimensions + print("\nLoading base model to analyze dimensions...") + model_wrapper = ProgressiveReasoningModel(config) + model_wrapper.setup_base_model() + + # Analyze model architecture + print(f"\nBase model: {config['experiment']['base_model']}") + + # Count parameters in base model + total_params = sum(p.numel() for p in model_wrapper.model.parameters()) + print(f"Total base model parameters: {total_params:,}") + + # Load saved adapter if it exists + adapter_path = Path(config["experiment"]["output_dir"]) / 
"adapters" / "basic_cot" + if adapter_path.exists(): + print(f"\nLoading saved adapter from: {adapter_path}") + + # Load adapter state dict + adapter_model_path = adapter_path / "adapter_model.safetensors" + if not adapter_model_path.exists(): + adapter_model_path = adapter_path / "adapter_model.bin" + + if adapter_model_path.exists(): + if adapter_model_path.suffix == ".safetensors": + from safetensors.torch import load_file + adapter_weights = load_file(adapter_model_path) + else: + adapter_weights = torch.load(adapter_model_path, map_location="cpu") + + print("\nLoRA Adapter Layer Details:") + print("-" * 60) + + total_lora_params = 0 + layer_info = {} + + for name, tensor in adapter_weights.items(): + size = tensor.numel() + total_lora_params += size + + # Parse layer name + parts = name.split('.') + if 'lora_A' in name or 'lora_B' in name: + # Extract module info + module_name = '.'.join(parts[:-2]) + lora_type = parts[-2] # lora_A or lora_B + + if module_name not in layer_info: + layer_info[module_name] = {} + + layer_info[module_name][lora_type] = { + 'shape': list(tensor.shape), + 'params': size + } + + # Display layer information + for module, info in sorted(layer_info.items()): + print(f"\nModule: {module}") + if 'lora_A' in info and 'lora_B' in info: + shape_a = info['lora_A']['shape'] + shape_b = info['lora_B']['shape'] + params_a = info['lora_A']['params'] + params_b = info['lora_B']['params'] + + print(f" LoRA A: {shape_a} = {params_a:,} parameters") + print(f" LoRA B: {shape_b} = {params_b:,} parameters") + print(f" Total: {params_a + params_b:,} parameters") + + # Calculate original layer size (approximation) + original_size = shape_a[1] * shape_b[0] + compression_ratio = original_size / (params_a + params_b) + print(f" Original layer size (approx): {original_size:,} parameters") + print(f" Compression ratio: {compression_ratio:.1f}x") + + print("\n" + "=" * 60) + print(f"Total LoRA parameters: {total_lora_params:,}") + print(f"Percentage of base model: {(total_lora_params / total_params) * 100:.2f}%") + + # Calculate theoretical size + r = adapter_config['r'] + num_modules = len(adapter_config['target_modules']) + + # For GPT models, typical dimensions + if "DialoGPT" in config['experiment']['base_model']: + hidden_size = 768 # DialoGPT-small uses 768 + print(f"\nTheoretical calculation (hidden_size={hidden_size}, r={r}):") + print(f" Per module: 2 * {hidden_size} * {r} = {2 * hidden_size * r:,} parameters") + print(f" Total ({num_modules} modules): {2 * hidden_size * r * num_modules:,} parameters") + else: + print(f"\nNo saved adapter found at: {adapter_path}") + print("Run training first to generate the adapter.") + + # Show theoretical sizes based on config + r = adapter_config['r'] + print(f"\nTheoretical LoRA sizes with r={r}:") + print(f" For hidden_size=768 (DialoGPT-small): {2 * 768 * r:,} params per module") + print(f" For hidden_size=1024 (medium models): {2 * 1024 * r:,} params per module") + print(f" For hidden_size=1280 (GPT-2 large): {2 * 1280 * r:,} params per module") + + +if __name__ == "__main__": + analyze_adapter_sizes() \ No newline at end of file diff --git a/scripts/check_vram.py b/scripts/check_vram.py new file mode 100644 index 0000000..869dc02 --- /dev/null +++ b/scripts/check_vram.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +""" +Check VRAM usage and model memory requirements +""" + +import torch +import psutil +import sys +from pathlib import Path +import yaml + +# Add src to path +sys.path.append(str(Path(__file__).parent.parent)) + +from 
transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig + + +def get_memory_info(): + """Get current memory usage""" + if torch.cuda.is_available(): + print("=== CUDA Information ===") + print(f"CUDA available: {torch.cuda.is_available()}") + print(f"CUDA device: {torch.cuda.get_device_name(0)}") + print(f"CUDA device count: {torch.cuda.device_count()}") + + # Get VRAM info + vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3 + vram_reserved = torch.cuda.memory_reserved(0) / 1024**3 + vram_allocated = torch.cuda.memory_allocated(0) / 1024**3 + vram_free = vram_total - vram_allocated + + print(f"\n=== VRAM Usage ===") + print(f"Total VRAM: {vram_total:.2f} GB") + print(f"Allocated VRAM: {vram_allocated:.2f} GB") + print(f"Reserved VRAM: {vram_reserved:.2f} GB") + print(f"Free VRAM: {vram_free:.2f} GB") + else: + print("CUDA not available!") + + # Get system RAM info + ram = psutil.virtual_memory() + print(f"\n=== System RAM ===") + print(f"Total RAM: {ram.total / 1024**3:.2f} GB") + print(f"Available RAM: {ram.available / 1024**3:.2f} GB") + print(f"Used RAM: {ram.used / 1024**3:.2f} GB ({ram.percent}%)") + + +def estimate_model_size(model_name: str, quantization: str = None): + """Estimate model memory requirements""" + print(f"\n=== Model Memory Estimation ===") + print(f"Model: {model_name}") + + # Common model sizes (in billions of parameters) + model_sizes = { + "gemma-2-2b": 2.5, + "gemma-3-1b": 1.2, + "llama-3.2-8b": 8, + "llama-3.2-13b": 13, + "llama-3.2-70b": 70, + } + + # Find model size + model_key = None + for key in model_sizes: + if key in model_name.lower(): + model_key = key + break + + if model_key: + params_billions = model_sizes[model_key] + + # Memory estimates (rough) + fp32_gb = params_billions * 4 # 4 bytes per parameter + fp16_gb = params_billions * 2 # 2 bytes per parameter + int8_gb = params_billions * 1 # 1 byte per parameter + int4_gb = params_billions * 0.5 # 0.5 bytes per parameter + + print(f"Estimated parameters: {params_billions}B") + print(f"Memory requirements:") + print(f" FP32: ~{fp32_gb:.1f} GB") + print(f" FP16/BF16: ~{fp16_gb:.1f} GB") + print(f" INT8: ~{int8_gb:.1f} GB") + print(f" INT4 (QLoRA): ~{int4_gb:.1f} GB") + + # Add overhead for activations and gradients + print(f"\nWith training overhead:") + print(f" FP16 + LoRA: ~{fp16_gb * 1.5:.1f} GB") + print(f" INT4 + QLoRA: ~{int4_gb * 1.5:.1f} GB") + else: + print("Model size not recognized, unable to estimate memory requirements") + + +def suggest_offloading_strategies(): + """Suggest CPU offloading strategies""" + print("\n=== CPU Offloading Strategies ===") + print("\n1. **Device Map Auto with CPU Offload**") + print(" ```python") + print(" device_map = {") + print(" 'model.embed_tokens': 'cpu',") + print(" 'model.layers.0': 0, # GPU") + print(" 'model.layers.1': 0, # GPU") + print(" 'model.layers.2': 'cpu', # CPU") + print(" # ... distribute layers between GPU and CPU") + print(" }") + print(" ```") + + print("\n2. **Accelerate's CPU Offload**") + print(" ```yaml") + print(" model:") + print(" device_map: 'auto'") + print(" max_memory:") + print(" 0: '4GB' # Limit GPU memory") + print(" 'cpu': '20GB' # Allow CPU memory") + print(" ```") + + print("\n3. 
**DeepSpeed ZeRO-Offload**") + print(" - ZeRO-2: Offload optimizer states to CPU") + print(" - ZeRO-3: Offload optimizer states and parameters to CPU") + print(" ```yaml") + print(" deepspeed:") + print(" zero_optimization:") + print(" stage: 2") + print(" offload_optimizer:") + print(" device: 'cpu'") + print(" ```") + + print("\n4. **Gradient Checkpointing + CPU Offload**") + print(" - Trade compute for memory") + print(" - Combine with layer-wise CPU offloading") + + print("\n5. **QLoRA with CPU Offload**") + print(" - 4-bit quantization reduces base model size") + print(" - Only LoRA parameters on GPU") + print(" - Base model layers can be on CPU") + + +def check_config_compatibility(config_path: str): + """Check if config is compatible with CPU offloading""" + if Path(config_path).exists(): + with open(config_path) as f: + config = yaml.safe_load(f) + + print(f"\n=== Config Analysis: {config_path} ===") + model_config = config.get("model", {}) + + print(f"Current settings:") + print(f" 4-bit quantization: {model_config.get('load_in_4bit', False)}") + print(f" Gradient checkpointing: {model_config.get('gradient_checkpointing', False)}") + print(f" Device map: {model_config.get('device_map', 'None')}") + + if model_config.get('load_in_4bit', False): + print("✓ Already using 4-bit quantization (good for memory)") + else: + print("✗ Consider enabling 4-bit quantization") + + if not model_config.get('gradient_checkpointing', False): + print("✗ Consider enabling gradient checkpointing") + + +def main(): + """Main function""" + print("VRAM and Memory Analysis for Progressive LLM Training") + print("=" * 60) + + # Get memory info + get_memory_info() + + # Estimate model sizes + models = [ + "google/gemma-2-2b-it", + "google/gemma-3-1b-pt", + "meta-llama/Llama-3.2-8B", + ] + + for model in models: + estimate_model_size(model) + + # Suggest strategies + suggest_offloading_strategies() + + # Check configs + configs = [ + "config/training_config_gemma3_1b.yaml", + "config/training_config_gemma2_small.yaml", + ] + + for config in configs: + check_config_compatibility(config) + + print("\n=== Recommendations ===") + print("1. Start with QLoRA (4-bit) if not already enabled") + print("2. Use device_map with max_memory limits") + print("3. Enable gradient checkpointing") + print("4. Consider DeepSpeed for advanced offloading") + print("5. 
Monitor actual usage during training") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/compare_models_tui.py b/scripts/compare_models_tui.py new file mode 100755 index 0000000..b2913b1 --- /dev/null +++ b/scripts/compare_models_tui.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +""" +TUI for comparing original and trained models +""" + +import sys +from pathlib import Path +import yaml +import torch +from rich.console import Console +from rich.panel import Panel +from rich.columns import Columns +from rich.prompt import Prompt +from rich.text import Text +from rich.layout import Layout +from rich.live import Live +from rich.table import Table +import time + +# Add src to path +sys.path.append(str(Path(__file__).parent.parent)) + +from src.progressive_model import ProgressiveReasoningModel + + +class ModelCompareTUI: + def __init__(self, config_path: str = "config/training_config.yaml"): + self.console = Console() + + # Load configuration + with open(config_path) as f: + self.config = yaml.safe_load(f) + + # Initialize models + self.console.print("[yellow]Loading models...[/yellow]") + + # Original model + self.original_model = ProgressiveReasoningModel(self.config) + self.original_model.setup_base_model() + + # Trained model + self.trained_model = ProgressiveReasoningModel(self.config) + self.trained_model.setup_base_model() + + # Load the trained adapter if it exists + adapter_path = Path(self.config["experiment"]["output_dir"]) / "adapters" / "basic_cot" + if adapter_path.exists(): + self.console.print(f"[green]Loading trained adapter from: {adapter_path}[/green]") + self.trained_model.load_for_inference(["basic_cot"]) + else: + self.console.print("[red]No trained adapter found. Please run training first.[/red]") + self.console.print("[yellow]Both models will show original behavior.[/yellow]") + + self.console.print("[green]Models loaded successfully![/green]\n") + + def generate_response(self, model, prompt: str, with_think_tags: bool = True) -> str: + """Generate response from a model""" + # For trained model, encourage think tags + if with_think_tags and model == self.trained_model: + formatted_prompt = f"{prompt}\n\nPlease think step by step." 
+ else: + formatted_prompt = prompt + + inputs = model.tokenizer(formatted_prompt, return_tensors="pt").to(model.model.device) + + with torch.no_grad(): + outputs = model.model.generate( + **inputs, + max_length=512, + temperature=0.7, + do_sample=True, + top_p=0.95, + pad_token_id=model.tokenizer.pad_token_id, + eos_token_id=model.tokenizer.eos_token_id + ) + + response = model.tokenizer.decode(outputs[0], skip_special_tokens=True) + + # Extract response after prompt + response = response[len(formatted_prompt):].strip() + + return response + + def create_comparison_panel(self, prompt: str, original_response: str, trained_response: str) -> Panel: + """Create a panel showing the comparison""" + # Create table + table = Table(show_header=True, header_style="bold magenta", expand=True) + table.add_column("Original Model", style="cyan", width=50) + table.add_column("Trained Model (with CoT)", style="green", width=50) + + table.add_row(original_response, trained_response) + + return Panel( + table, + title=f"[bold yellow]Prompt: {prompt}[/bold yellow]", + border_style="blue" + ) + + def run_interactive_mode(self): + """Run interactive comparison mode""" + self.console.print("\n[bold cyan]Model Comparison TUI[/bold cyan]") + self.console.print("Compare responses from original and trained models\n") + self.console.print("[dim]Type 'quit' or 'exit' to leave[/dim]\n") + + while True: + # Get user prompt + prompt = Prompt.ask("\n[bold yellow]Enter your prompt[/bold yellow]") + + if prompt.lower() in ['quit', 'exit']: + self.console.print("\n[yellow]Goodbye![/yellow]") + break + + # Generate responses + self.console.print("\n[dim]Generating responses...[/dim]") + + start_time = time.time() + original_response = self.generate_response(self.original_model, prompt, with_think_tags=False) + original_time = time.time() - start_time + + start_time = time.time() + trained_response = self.generate_response(self.trained_model, prompt, with_think_tags=True) + trained_time = time.time() - start_time + + # Display comparison + panel = self.create_comparison_panel(prompt, original_response, trained_response) + self.console.print(panel) + + # Show generation times + self.console.print(f"\n[dim]Generation times - Original: {original_time:.2f}s, Trained: {trained_time:.2f}s[/dim]") + + def run_benchmark_mode(self): + """Run benchmark with predefined prompts""" + test_prompts = [ + "What is 156 + 389?", + "If I have 23 apples and buy 17 more, how many do I have?", + "A store has 145 items. 
If 38 are sold, how many remain?", + "What is 45 * 12?", + "Explain why 2 + 2 = 4", + "If a train travels 80 km/h for 2.5 hours, how far does it go?", + "What is the sum of all numbers from 1 to 10?", + "How many minutes are in 3.5 hours?", + ] + + self.console.print("\n[bold cyan]Running Benchmark Comparison[/bold cyan]\n") + + for i, prompt in enumerate(test_prompts, 1): + self.console.print(f"[bold]Test {i}/{len(test_prompts)}[/bold]") + + # Generate responses + original_response = self.generate_response(self.original_model, prompt, with_think_tags=False) + trained_response = self.generate_response(self.trained_model, prompt, with_think_tags=True) + + # Display comparison + panel = self.create_comparison_panel(prompt, original_response, trained_response) + self.console.print(panel) + self.console.print("") + + self.console.print("[green]Benchmark completed![/green]") + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Compare original and trained models") + parser.add_argument("--mode", choices=["interactive", "benchmark"], default="interactive", + help="Mode to run the comparison") + parser.add_argument("--config", default="config/training_config.yaml", + help="Path to configuration file") + + args = parser.parse_args() + + # Create TUI + tui = ModelCompareTUI(args.config) + + # Run in selected mode + if args.mode == "interactive": + tui.run_interactive_mode() + else: + tui.run_benchmark_mode() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/evaluate.py b/scripts/evaluate.py new file mode 100755 index 0000000..485ca76 --- /dev/null +++ b/scripts/evaluate.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +""" +Evaluation script for progressive model +""" + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +from src.progressive_model import ProgressiveReasoningModel +import yaml + + +def evaluate_reasoning(model_wrapper, test_prompts): + """Evaluate model on test prompts""" + results = [] + + for prompt in test_prompts: + print(f"\nPrompt: {prompt}") + response = model_wrapper.generate_with_reasoning(prompt) + print(f"Response: {response}") + results.append({ + "prompt": prompt, + "response": response + }) + + return results + + +def main(): + # Load config + with open("config/training_config.yaml") as f: + config = yaml.safe_load(f) + + # Initialize model + model_wrapper = ProgressiveReasoningModel(config) + model_wrapper.setup_base_model() + + # Test different adapters + test_prompts = [ + "What is 156 + 389?", + "If a train travels 80 km/h for 2.5 hours, how far does it go?", + "Explain why the sky is blue.", + ] + + # Test each adapter + for adapter_name in ["basic_cot", "math_reasoning", "complex_reasoning"]: + if adapter_name in model_wrapper.adapters: + print(f"\n{'='*50}") + print(f"Testing adapter: {adapter_name}") + print(f"{'='*50}") + + model_wrapper.load_for_inference([adapter_name]) + results = evaluate_reasoning(model_wrapper, test_prompts) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/simple_compare.py b/scripts/simple_compare.py new file mode 100755 index 0000000..dcea3ff --- /dev/null +++ b/scripts/simple_compare.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +""" +Simple comparison script without rich TUI +""" + +import sys +from pathlib import Path +import yaml +import torch +import argparse + +# Add src to path +sys.path.append(str(Path(__file__).parent.parent)) + +from src.progressive_model import 
ProgressiveReasoningModel + + +def parse_args(): + parser = argparse.ArgumentParser(description="Compare original and trained models") + parser.add_argument( + "--config", "-c", + type=str, + default="config/training_config_gemma2_small.yaml", + help="Path to configuration file" + ) + parser.add_argument( + "--adapter", "-a", + type=str, + default="basic_cot", + help="Adapter name to load for comparison" + ) + parser.add_argument( + "--max-length", + type=int, + default=512, + help="Maximum generation length" + ) + return parser.parse_args() + + +def load_config(config_path): + """Load configuration from file""" + config_path = Path(config_path) + if not config_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + with open(config_path) as f: + config = yaml.safe_load(f) + return config + + +def generate_response(model, tokenizer, prompt, max_length=512): + """Generate response using the model""" + # Format prompt for Gemma + formatted_prompt = f"user\n{prompt}\nmodel\n" + + # Tokenize + inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device) + + # Generate + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_length=len(inputs["input_ids"][0]) + max_length, + temperature=0.7, + do_sample=True, + top_p=0.9, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + repetition_penalty=1.1, + ) + + # Decode + response = tokenizer.decode(outputs[0], skip_special_tokens=True) + + # Extract only the model's response + if "model" in response: + response = response.split("model")[-1].strip() + + return response + + +def main(): + args = parse_args() + + try: + config = load_config(args.config) + except FileNotFoundError as e: + print(f"Error: {e}") + return + + print(f"Progressive Model Comparison") + print(f"Config: {args.config}") + print(f"Base model: {config['experiment']['base_model']}") + print(f"Adapter: {args.adapter}") + print("="*60) + + print("Loading models...") + + # Original model (no adapter) + print("Loading original model...") + original_model = ProgressiveReasoningModel(config) + original_model.setup_base_model() + + # Trained model (with adapter) + print("Loading trained model...") + trained_model = ProgressiveReasoningModel(config) + trained_model.setup_base_model() + + # Load the trained adapter if it exists + adapter_path = Path(config["experiment"]["output_dir"]) / "adapters" / args.adapter + if adapter_path.exists(): + print(f"Loading trained adapter from: {adapter_path}") + try: + trained_model.load_for_inference([args.adapter]) + print("Adapter loaded successfully!") + except Exception as e: + print(f"Error loading adapter: {e}") + print("Will compare with base model instead.") + else: + print(f"No trained adapter found at: {adapter_path}") + print("Available adapters:") + adapters_dir = Path(config["experiment"]["output_dir"]) / "adapters" + if adapters_dir.exists(): + for adapter_dir in adapters_dir.iterdir(): + if adapter_dir.is_dir(): + print(f" - {adapter_dir.name}") + else: + print(" No adapters directory found.") + print("Both models will show original behavior.") + + print("\nModels loaded! 
Enter prompts to compare (type 'quit' to exit)") + print("Examples:") + print(" - What is 25 + 17?") + print(" - Explain why the sky is blue") + print(" - Solve this step by step: If I have 10 apples and give away 3, how many do I have left?") + print() + + while True: + try: + prompt = input("\nPrompt: ").strip() + if prompt.lower() in ['quit', 'exit', 'q']: + break + + if not prompt: + continue + + print(f"\n{'='*60}") + print("ORIGINAL MODEL (No fine-tuning)") + print("="*60) + + try: + original_response = generate_response( + original_model.model, + original_model.tokenizer, + prompt, + args.max_length + ) + print(original_response) + except Exception as e: + print(f"Error generating original response: {e}") + + print(f"\n{'='*60}") + print(f"TRAINED MODEL (With {args.adapter} adapter)") + print("="*60) + + try: + # Add CoT prompt for trained model + cot_prompt = f"{prompt}\n\nPlease think step by step using tags." + trained_response = generate_response( + trained_model.model, + trained_model.tokenizer, + cot_prompt, + args.max_length + ) + print(trained_response) + except Exception as e: + print(f"Error generating trained response: {e}") + + except KeyboardInterrupt: + print("\nExiting...") + break + except Exception as e: + print(f"Error: {e}") + continue + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/train_progressive.py b/scripts/train_progressive.py new file mode 100755 index 0000000..d3cd938 --- /dev/null +++ b/scripts/train_progressive.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +""" +Main training script for progressive reasoning model +""" + +import sys +import yaml +import argparse +from pathlib import Path + +# Add src to path +sys.path.append(str(Path(__file__).parent.parent)) + +from src.progressive_model import ProgressiveReasoningModel +from src.training import ProgressiveTrainer +from src.data_utils import prepare_sample_datasets + + +def parse_args(): + """Parse command line arguments""" + parser = argparse.ArgumentParser( + description="Progressive LLM Training for 松尾研LLMコンペ2025", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Use default config + python scripts/train_progressive.py + + # Use specific config file + python scripts/train_progressive.py --config config/training_config_large.yaml + + # Use config with custom path + python scripts/train_progressive.py --config /path/to/my_config.yaml + + # Prepare sample datasets + python scripts/train_progressive.py --prepare-data + """ + ) + + parser.add_argument( + "--config", "-c", + type=str, + default="config/training_config.yaml", + help="Path to the training configuration file (default: config/training_config.yaml)" + ) + + parser.add_argument( + "--prepare-data", + action="store_true", + help="Prepare sample datasets before training" + ) + + parser.add_argument( + "--dry-run", + action="store_true", + help="Load config and model but skip training (for testing)" + ) + + return parser.parse_args() + + +def load_config(config_path: str) -> dict: + """Load configuration from file""" + config_path = Path(config_path) + + if not config_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + print(f"Loading configuration from: {config_path}") + + with open(config_path) as f: + config = yaml.safe_load(f) + + return config + + +def main(): + args = parse_args() + + print("Progressive LLM Training for 松尾研LLMコンペ2025") + print("=" * 50) + + # Load configuration + try: + config = load_config(args.config) + 
except FileNotFoundError as e: + print(f"Error: {e}") + print("Available config files:") + config_dir = Path("config") + if config_dir.exists(): + for config_file in config_dir.glob("*.yaml"): + print(f" {config_file}") + sys.exit(1) + except Exception as e: + print(f"Error loading config: {e}") + sys.exit(1) + + # Print configuration info + print(f"Experiment: {config['experiment']['name']}") + print(f"Base model: {config['experiment']['base_model']}") + print(f"Output directory: {config['experiment']['output_dir']}") + print(f"Stages: {len(config['progressive_stages'])}") + + # Prepare sample datasets if requested + if args.prepare_data: + print("\nPreparing sample datasets...") + prepare_sample_datasets() + print("Sample datasets prepared.") + + # Initialize model wrapper + print("\nInitializing model...") + model_wrapper = ProgressiveReasoningModel(config) + model_wrapper.setup_base_model() + + if args.dry_run: + print("\nDry run completed. Model loaded successfully.") + return + + # Initialize trainer + print("\nInitializing trainer...") + trainer = ProgressiveTrainer(model_wrapper, config) + + # Run progressive training + print("\nStarting progressive training...") + trainer.run_progressive_training() + + print("\nTraining completed successfully!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/data_utils.py b/src/data_utils.py new file mode 100644 index 0000000..0c7370a --- /dev/null +++ b/src/data_utils.py @@ -0,0 +1,88 @@ +import json +import jsonlines +from typing import List, Dict +from pathlib import Path +import random + + +def create_think_tag_example(question: str, reasoning: str, answer: str) -> Dict: + """Create training example with think tags""" + output = f"\n{reasoning}\n\n\n{answer}" + + return { + "input": question, + "output": output + } + + +def prepare_basic_cot_data(output_dir: str, num_examples: int = 1000): + """Create basic Chain-of-Thought examples""" + output_path = Path(output_dir) / "basic_cot" + output_path.mkdir(parents=True, exist_ok=True) + + examples = [] + + # Simple arithmetic examples + for i in range(num_examples // 2): + a = random.randint(10, 100) + b = random.randint(10, 100) + question = f"What is {a} + {b}?" + reasoning = f"To find {a} + {b}, I need to add these two numbers together.\n{a} + {b} = {a + b}" + answer = f"The answer is {a + b}." + + examples.append(create_think_tag_example(question, reasoning, answer)) + + # Simple word problems + templates = [ + { + "question": "If I have {a} apples and buy {b} more, how many apples do I have?", + "reasoning": "Starting with {a} apples, then adding {b} more apples.\nTotal: {a} + {b} = {result}", + "answer": "I have {result} apples." + }, + { + "question": "A store has {a} items. If {b} are sold, how many remain?", + "reasoning": "Starting amount: {a} items\nSold: {b} items\nRemaining: {a} - {b} = {result}", + "answer": "There are {result} items remaining." 
+ } + ] + + for i in range(num_examples // 2): + template = random.choice(templates) + a = random.randint(20, 200) + b = random.randint(10, min(50, a)) + + if "+" in template["reasoning"]: + result = a + b + else: + result = a - b + + question = template["question"].format(a=a, b=b) + reasoning = template["reasoning"].format(a=a, b=b, result=result) + answer = template["answer"].format(result=result) + + examples.append(create_think_tag_example(question, reasoning, answer)) + + # Save to jsonl + output_file = output_path / "train.jsonl" + with jsonlines.open(output_file, "w") as writer: + writer.write_all(examples) + + print(f"Created {len(examples)} basic CoT examples at: {output_file}") + + +def prepare_sample_datasets(base_dir: str = "./data"): + """Prepare sample datasets for all stages""" + base_path = Path(base_dir) + + # Basic CoT + prepare_basic_cot_data(base_path) + + # Math reasoning (placeholder) + math_path = base_path / "math_reasoning" + math_path.mkdir(parents=True, exist_ok=True) + + # Complex reasoning (placeholder) + complex_path = base_path / "complex_reasoning" + complex_path.mkdir(parents=True, exist_ok=True) + + print(f"Sample datasets prepared in: {base_path}") \ No newline at end of file diff --git a/src/progressive_model.py b/src/progressive_model.py new file mode 100644 index 0000000..6c5e669 --- /dev/null +++ b/src/progressive_model.py @@ -0,0 +1,366 @@ +import torch +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + TrainingArguments +) +from peft import ( + LoraConfig, + PeftModel, + TaskType, + get_peft_model, + prepare_model_for_kbit_training +) +from typing import Dict, List, Optional, Tuple +import json +from pathlib import Path + + +class ProgressiveReasoningModel: + """Progressive training approach for reasoning models""" + + def __init__(self, config: dict): + self.config = config + self.base_model_name = config["experiment"]["base_model"] + self.output_dir = Path(config["experiment"]["output_dir"]) + self.output_dir.mkdir(parents=True, exist_ok=True) + + self.model = None + self.tokenizer = None + self.adapters = {} + self.training_history = [] + + def setup_base_model(self): + """Initialize base model with quantization""" + print(f"Loading base model: {self.base_model_name}") + + # Check if quantization is enabled + if self.config["model"].get("load_in_4bit", False): + # BitsAndBytes config for 4-bit quantization + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=getattr(torch, self.config["model"]["bnb_4bit_compute_dtype"]), + bnb_4bit_use_double_quant=self.config["model"]["bnb_4bit_use_double_quant"], + bnb_4bit_quant_type=self.config["model"].get("bnb_4bit_quant_type", "nf4") + ) + quantization_config = bnb_config + else: + quantization_config = None + + # Model loading arguments + model_kwargs = { + "device_map": self.config["model"]["device_map"], + "trust_remote_code": True, + "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32, + } + + # Add authentication token if provided + if "hf_token" in self.config["model"] and self.config["model"]["hf_token"]: + model_kwargs["token"] = self.config["model"]["hf_token"] + + # Add max_memory configuration for CPU offloading + if "max_memory" in self.config["model"]: + model_kwargs["max_memory"] = self.config["model"]["max_memory"] + print(f"Using max_memory configuration: {model_kwargs['max_memory']}") + + # Add offload folder if specified + if "offload_folder" in self.config["model"]: + 
model_kwargs["offload_folder"] = self.config["model"]["offload_folder"] + model_kwargs["offload_state_dict"] = True + print(f"Using offload folder: {model_kwargs['offload_folder']}") + + # Note: llm_int8_enable_fp32_cpu_offload is not supported for all models + # Only add it if we're not using Gemma models + if (quantization_config and + self.config["model"].get("llm_int8_enable_fp32_cpu_offload", False) and + "gemma" not in self.base_model_name.lower()): + model_kwargs["llm_int8_enable_fp32_cpu_offload"] = True + print("Enabled FP32 CPU offload for quantized model") + + # Add quantization config if enabled + if quantization_config: + model_kwargs["quantization_config"] = quantization_config + + # Add attention implementation + if self.config["model"].get("use_flash_attention_2", False): + model_kwargs["attn_implementation"] = "flash_attention_2" + elif self.config["model"].get("use_eager_attention", False): + model_kwargs["attn_implementation"] = "eager" + + # Load model + print("Loading model with the following kwargs:") + for k, v in model_kwargs.items(): + if k != "quantization_config": + print(f" {k}: {v}") + else: + print(f" {k}: ") + + try: + self.model = AutoModelForCausalLM.from_pretrained( + self.base_model_name, + **model_kwargs + ) + except Exception as e: + print(f"Error loading model: {e}") + # Try without some problematic kwargs + if "offload_folder" in model_kwargs: + print("Retrying without offload_folder...") + del model_kwargs["offload_folder"] + del model_kwargs["offload_state_dict"] + self.model = AutoModelForCausalLM.from_pretrained( + self.base_model_name, + **model_kwargs + ) + + # Prepare for k-bit training if using quantization + if quantization_config: + self.model = prepare_model_for_kbit_training(self.model) + + # Disable gradient checkpointing for now to avoid conflicts + # Enable gradient checkpointing if requested (but disable use_cache) + # if self.config["model"].get("gradient_checkpointing", False): + # self.model.gradient_checkpointing_enable() + # self.model.config.use_cache = False + # print("Gradient checkpointing enabled, use_cache disabled") + + # Explicitly disable use_cache to avoid conflicts + if hasattr(self.model, 'config'): + self.model.config.use_cache = False + + # Load tokenizer + tokenizer_kwargs = {"trust_remote_code": True} + if "hf_token" in self.config["model"] and self.config["model"]["hf_token"]: + tokenizer_kwargs["token"] = self.config["model"]["hf_token"] + + self.tokenizer = AutoTokenizer.from_pretrained( + self.base_model_name, + **tokenizer_kwargs + ) + + # Set padding token and other special tokens for Gemma + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + # For Gemma models, ensure special tokens are set + if "gemma" in self.base_model_name.lower(): + print("Configuring Gemma-specific tokenizer settings") + # Add special tokens if they don't exist + special_tokens = { + "bos_token": "", + "eos_token": "", + "pad_token": "", + } + + # Only add tokens that don't already exist + tokens_to_add = {} + for token_name, token_value in special_tokens.items(): + if getattr(self.tokenizer, token_name, None) is None: + tokens_to_add[token_name] = token_value + + if tokens_to_add: + num_added = self.tokenizer.add_special_tokens(tokens_to_add) + print(f"Added special tokens: {tokens_to_add}") + if num_added > 0: + # Resize model embeddings to accommodate new tokens + self.model.resize_token_embeddings(len(self.tokenizer)) + print(f"Resized model embeddings to {len(self.tokenizer)} tokens") + 
+ # Set appropriate model_max_length for Gemma + if hasattr(self.tokenizer, 'model_max_length') and self.tokenizer.model_max_length > 8192: + self.tokenizer.model_max_length = 8192 + print(f"Set tokenizer model_max_length to {self.tokenizer.model_max_length}") + + # Debug: print model structure for target module identification + print("Model structure:") + for name, module in self.model.named_modules(): + if any(target in name for target in ['attn', 'proj', 'mlp', 'gate', 'up', 'down']): + print(f" {name}: {type(module).__name__}") + + print("Base model loaded successfully") + + def get_target_modules(self, suggested_modules): + """Auto-detect valid target modules for the model""" + valid_modules = [] + all_modules = [name for name, _ in self.model.named_modules()] + + # Check each suggested module + for module_name in suggested_modules: + # Find modules that contain this name + matching_modules = [name for name in all_modules if module_name in name] + if matching_modules: + valid_modules.append(module_name) + print(f" Found target module: {module_name} (matches: {len(matching_modules)} modules)") + else: + print(f" Warning: target module '{module_name}' not found in model") + + # If no valid modules found, try common alternatives + if not valid_modules: + print(" No suggested modules found, trying common alternatives...") + common_alternatives = [ + "q_proj", "k_proj", "v_proj", "o_proj", # Common attention + "gate_proj", "up_proj", "down_proj", # Common MLP + "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj", # Full path + "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj", # Full MLP path + ] + + for module_name in common_alternatives: + matching_modules = [name for name in all_modules if module_name in name] + if matching_modules: + valid_modules.append(module_name) + print(f" Found alternative target module: {module_name}") + if len(valid_modules) >= 2: # At least 2 modules + break + + if not valid_modules: + print(" ERROR: No valid target modules found!") + print(" Available modules containing 'proj' or 'attn':") + for name in all_modules: + if any(keyword in name.lower() for keyword in ['proj', 'attn', 'mlp']): + print(f" {name}") + # Fallback to a basic module that should exist + valid_modules = ["embed_tokens"] + + return valid_modules + + def create_adapter(self, stage_config: dict) -> LoraConfig: + """Create LoRA adapter configuration""" + adapter_config = stage_config["adapter_config"] + + # Get initialization method from config, default to True for identity init + init_method = adapter_config.get("init_lora_weights", True) + + # Auto-detect valid target modules + suggested_modules = adapter_config["target_modules"] + valid_modules = self.get_target_modules(suggested_modules) + + print(f"Using target modules: {valid_modules}") + + return LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=adapter_config["r"], + lora_alpha=adapter_config["lora_alpha"], + lora_dropout=adapter_config["lora_dropout"], + target_modules=valid_modules, + bias="none", + init_lora_weights=init_method # Initialize LoRA weights (True = identity, "gaussian" = random) + ) + + def add_progressive_adapter(self, stage_name: str, stage_config: dict): + """Add a new adapter for progressive training""" + print(f"\nAdding adapter for stage: {stage_name}") + + # Check if we should inherit from previous adapter + if "inherit_from" in stage_config and stage_config["inherit_from"] in self.adapters: + print(f"Inheriting from: {stage_config['inherit_from']}") + # Load previous adapter as base + 
prev_adapter_path = self.adapters[stage_config["inherit_from"]] + self.model = PeftModel.from_pretrained( + self.model, + prev_adapter_path, + is_trainable=True + ) + # Merge and unload to incorporate previous learning + self.model = self.model.merge_and_unload() + + # Create new adapter config + lora_config = self.create_adapter(stage_config) + + # Add adapter to model + self.model = get_peft_model(self.model, lora_config) + + # Ensure model is in training mode + self.model.train() + + # Print trainable parameters + self.model.print_trainable_parameters() + + # Debug: check if any parameters require gradients + trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + total_params = sum(p.numel() for p in self.model.parameters()) + print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)") + + # List parameters that require gradients + grad_params = [name for name, param in self.model.named_parameters() if param.requires_grad] + print(f"Parameters requiring gradients: {len(grad_params)} parameters") + if len(grad_params) > 0: + print(f"First few: {grad_params[:5]}") + else: + print("WARNING: No parameters require gradients!") + + # Save adapter path + adapter_path = self.output_dir / "adapters" / stage_name + adapter_path.mkdir(parents=True, exist_ok=True) + self.adapters[stage_name] = str(adapter_path) + + def save_adapter(self, stage_name: str): + """Save current adapter""" + if stage_name in self.adapters: + print(f"Saving adapter: {stage_name}") + self.model.save_pretrained(self.adapters[stage_name]) + # Also save tokenizer for convenience + self.tokenizer.save_pretrained(self.adapters[stage_name]) + + def load_for_inference(self, adapter_names: List[str], weights: Optional[Dict[str, float]] = None): + """Load model with specific adapters for inference""" + if len(adapter_names) == 1: + # Single adapter + adapter_name = adapter_names[0] + + # Check if adapter path is in memory + if adapter_name in self.adapters: + adapter_path = self.adapters[adapter_name] + else: + # Try to find adapter in output directory + adapter_path = self.output_dir / "adapters" / adapter_name + if not adapter_path.exists(): + raise ValueError(f"Adapter {adapter_name} not found at {adapter_path}") + adapter_path = str(adapter_path) + + print(f"Loading adapter from: {adapter_path}") + self.model = PeftModel.from_pretrained( + self.model, + adapter_path + ) + else: + # Multiple adapters - load and combine + # This is a simplified version - real implementation would need adapter composition + print("Multi-adapter inference not fully implemented in this bootstrap") + # For now, just load the last adapter + adapter_name = adapter_names[-1] + if adapter_name in self.adapters: + adapter_path = self.adapters[adapter_name] + else: + adapter_path = str(self.output_dir / "adapters" / adapter_name) + self.model = PeftModel.from_pretrained( + self.model, + adapter_path + ) + + def generate_with_reasoning(self, prompt: str, max_length: int = 2048) -> str: + """Generate response with reasoning""" + # Format prompt with think tags expectation + formatted_prompt = f"{prompt}\n\nPlease think step by step using tags before providing your answer." 
+ + # Tokenize + inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device) + + # Generate + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_length=max_length, + temperature=0.7, + do_sample=True, + top_p=0.95, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id + ) + + # Decode + response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + + # Extract response after prompt + response = response[len(formatted_prompt):].strip() + + return response \ No newline at end of file diff --git a/src/training.py b/src/training.py new file mode 100644 index 0000000..af6f63a --- /dev/null +++ b/src/training.py @@ -0,0 +1,450 @@ +from transformers import TrainingArguments +from trl import SFTTrainer +from datasets import load_dataset, Dataset +import torch +from typing import Dict, List +import json +import jsonlines +from pathlib import Path + + +class ProgressiveTrainer: + """Handle progressive training stages""" + + def __init__(self, model_wrapper, config: dict): + self.model_wrapper = model_wrapper + self.config = config + self.training_history = [] + + def load_dataset(self, dataset_path: str, stage_config: dict = None) -> Dataset: + """Load dataset from jsonl files or HuggingFace datasets""" + print(f"Loading dataset from path: {dataset_path}") + + # Check if it's a HuggingFace dataset (contains '/') + if '/' in dataset_path and not Path(dataset_path).exists(): + print(f"Loading HuggingFace dataset: {dataset_path}") + return self.load_huggingface_dataset(dataset_path, stage_config) + + # Load local dataset + data = [] + print(f"Current working directory: {Path.cwd()}") + + # Support both single file and directory + path = Path(dataset_path) + print(f"Path exists: {path.exists()}") + print(f"Is file: {path.is_file()}") + print(f"Is directory: {path.is_dir()}") + + if path.is_file(): + files = [path] + else: + files = list(path.glob("*.jsonl")) + + print(f"Found {len(files)} files to load") + for f in files: + print(f" - {f}") + + for file_path in files: + print(f"Loading file: {file_path}") + try: + with jsonlines.open(file_path) as reader: + count = 0 + for item in reader: + # Format for chat template + formatted = { + "messages": [ + {"role": "user", "content": item["input"]}, + {"role": "assistant", "content": item["output"]} + ] + } + data.append(formatted) + count += 1 + print(f" Loaded {count} examples from {file_path}") + except Exception as e: + print(f" Error loading file {file_path}: {e}") + + print(f"Total examples loaded: {len(data)}") + return Dataset.from_list(data) + + def load_huggingface_dataset(self, dataset_name: str, stage_config: dict) -> Dataset: + """Load dataset from HuggingFace""" + try: + dataset_config = stage_config.get("dataset_config", {}) if stage_config else {} + + # Default settings + split = dataset_config.get("split", "train") + max_samples = dataset_config.get("max_samples", None) + streaming = dataset_config.get("streaming", False) + + print(f"Loading HuggingFace dataset: {dataset_name}") + print(f" Split: {split}") + print(f" Max samples: {max_samples}") + print(f" Streaming: {streaming}") + + # Load dataset + if streaming: + dataset = load_dataset(dataset_name, split=split, streaming=True) + if max_samples: + dataset = dataset.take(max_samples) + # Convert streaming dataset to regular dataset + data = [] + count = 0 + for item in dataset: + data.append(item) + count += 1 + if count % 1000 == 0: + print(f" Loaded {count} examples...") + if max_samples and count 
>= max_samples: + break + dataset = Dataset.from_list(data) + else: + dataset = load_dataset(dataset_name, split=split) + if max_samples: + dataset = dataset.select(range(min(max_samples, len(dataset)))) + + print(f" Loaded dataset with {len(dataset)} examples") + print(f" Dataset columns: {dataset.column_names}") + if len(dataset) > 0: + print(f" First example: {dataset[0]}") + + # Convert to our expected format based on dataset name + if "math" in dataset_name.lower(): + return self.convert_math_dataset(dataset) + elif "mixture-of-thoughts" in dataset_name.lower(): + return self.convert_mixture_of_thoughts_dataset(dataset) + else: + return self.convert_generic_dataset(dataset) + + except Exception as e: + print(f"Error loading HuggingFace dataset {dataset_name}: {e}") + print("Falling back to empty dataset") + return Dataset.from_list([]) + + def convert_math_dataset(self, dataset: Dataset) -> Dataset: + """Convert OpenR1-Math-220k format to our training format""" + def format_math_example(example): + # OpenR1-Math-220k format has different column names + # Try to find the right columns + input_text = None + output_text = None + + # Common column names in math datasets + if "question" in example: + input_text = example["question"] + elif "problem" in example: + input_text = example["problem"] + elif "input" in example: + input_text = example["input"] + elif "query" in example: + input_text = example["query"] + + if "answer" in example: + output_text = example["answer"] + elif "solution" in example: + output_text = example["solution"] + elif "output" in example: + output_text = example["output"] + elif "response" in example: + output_text = example["response"] + + # If we can't find the right columns, use the raw example + if input_text is None or output_text is None: + print(f"Warning: Could not parse example columns: {list(example.keys())}") + # Try to use the first two string fields + string_fields = [k for k, v in example.items() if isinstance(v, str) and len(v) > 10] + if len(string_fields) >= 2: + input_text = example[string_fields[0]] + output_text = example[string_fields[1]] + else: + # Skip this example + return None + + # Format with think tags for math reasoning + formatted_output = f"\nLet me solve this step by step.\n\n{output_text}\n\n\n{output_text}" + + return { + "messages": [ + {"role": "user", "content": input_text}, + {"role": "assistant", "content": formatted_output} + ] + } + + # Convert and filter out None results + converted = dataset.map(format_math_example, desc="Converting math dataset") + converted = converted.filter(lambda x: x is not None, desc="Filtering valid examples") + + print(f"Converted {len(converted)} math examples") + if len(converted) > 0: + print(f"First converted example: {converted[0]}") + + return converted + + def convert_mixture_of_thoughts_dataset(self, dataset: Dataset) -> Dataset: + """Convert Mixture-of-Thoughts format to our training format""" + def format_mot_example(example): + # Mixture-of-Thoughts typically has complex reasoning patterns + # Check for common column names in the dataset + input_text = None + output_text = None + + # Try to identify input/output columns + if "prompt" in example: + input_text = example["prompt"] + elif "question" in example: + input_text = example["question"] + elif "input" in example: + input_text = example["input"] + elif "instruction" in example: + input_text = example["instruction"] + + if "response" in example: + output_text = example["response"] + elif "output" in example: + output_text = 
example["output"] + elif "completion" in example: + output_text = example["completion"] + elif "answer" in example: + output_text = example["answer"] + + # If columns not found, look for thinking patterns + if input_text is None or output_text is None: + # Try to find columns with substantial text + for key, value in example.items(): + if isinstance(value, str) and len(value) > 20: + if input_text is None and any(q in key.lower() for q in ["prompt", "question", "input"]): + input_text = value + elif output_text is None and any(a in key.lower() for a in ["response", "answer", "output"]): + output_text = value + + if input_text is None or output_text is None: + print(f"Warning: Could not parse MoT example columns: {list(example.keys())}") + return None + + # Check if output already contains thinking tags + if "" in output_text or "思考" in output_text: + # Already formatted with thinking + formatted_output = output_text + else: + # Add thinking structure for complex reasoning + formatted_output = f"\nLet me break this down step by step.\n\n{output_text}\n\n\nBased on my analysis, {output_text}" + + return { + "messages": [ + {"role": "user", "content": input_text}, + {"role": "assistant", "content": formatted_output} + ] + } + + # Convert and filter + converted = dataset.map(format_mot_example, desc="Converting Mixture-of-Thoughts dataset") + converted = converted.filter(lambda x: x is not None, desc="Filtering valid examples") + + print(f"Converted {len(converted)} Mixture-of-Thoughts examples") + if len(converted) > 0: + print(f"First converted example: {converted[0]}") + + return converted + + def convert_generic_dataset(self, dataset: Dataset) -> Dataset: + """Convert generic dataset format to our training format""" + def format_generic_example(example): + # Generic conversion for unknown dataset formats + input_text = None + output_text = None + + # Look for any text columns + text_columns = [(k, v) for k, v in example.items() if isinstance(v, str) and len(v) > 10] + + if len(text_columns) >= 2: + # Use first two substantial text columns + input_text = text_columns[0][1] + output_text = text_columns[1][1] + elif len(text_columns) == 1: + # Only one text column - skip this example + return None + else: + return None + + return { + "messages": [ + {"role": "user", "content": input_text}, + {"role": "assistant", "content": output_text} + ] + } + + converted = dataset.map(format_generic_example, desc="Converting generic dataset") + converted = converted.filter(lambda x: x is not None, desc="Filtering valid examples") + + print(f"Converted {len(converted)} generic examples") + return converted + + def format_dataset(self, dataset: Dataset) -> Dataset: + """Format dataset for training""" + print(f"Dataset before formatting: {len(dataset)} examples") + print(f"First example: {dataset[0] if len(dataset) > 0 else 'No data'}") + + # Check if tokenizer has chat template + has_chat_template = ( + hasattr(self.model_wrapper.tokenizer, 'chat_template') and + self.model_wrapper.tokenizer.chat_template is not None + ) + + if not has_chat_template: + print("No chat template found, setting default Gemma chat template") + # Set a simple chat template for Gemma + self.model_wrapper.tokenizer.chat_template = "{% for message in messages %}{{ message['role'] }}\n{{ message['content'] }}\n{% endfor %}" + + def format_chat(example): + # Try to use chat template if available + if has_chat_template or self.model_wrapper.tokenizer.chat_template: + try: + text = self.model_wrapper.tokenizer.apply_chat_template( + 
example["messages"], + tokenize=False, + add_generation_prompt=False + ) + return {"text": text} + except Exception as e: + print(f"Chat template failed: {e}, using fallback") + + # Fallback: create simple formatted text + if "messages" in example: + user_msg = example["messages"][0]["content"] + assistant_msg = example["messages"][1]["content"] + return {"text": f"user\n{user_msg}\nmodel\n{assistant_msg}\n"} + elif "input" in example and "output" in example: + return {"text": f"user\n{example['input']}\nmodel\n{example['output']}\n"} + else: + return {"text": str(example)} + + # Format dataset + formatted = dataset.map(format_chat, batched=False, desc="Formatting dataset") + print(f"Dataset after formatting: {len(formatted)} examples") + if len(formatted) > 0: + print(f"Columns: {formatted.column_names}") + print(f"First formatted example: {formatted[0]}") + + # Keep only the 'text' column for SFTTrainer + if 'text' in formatted.column_names: + columns_to_remove = [col for col in formatted.column_names if col != 'text'] + if columns_to_remove: + formatted = formatted.remove_columns(columns_to_remove) + + return formatted + + def filter_by_length(self, dataset: Dataset, max_length: int) -> Dataset: + """Filter dataset by sequence length""" + def is_valid_length(example): + # Tokenize and check length + tokens = self.model_wrapper.tokenizer( + example["text"], + truncation=False, + return_length=True + ) + return len(tokens["input_ids"]) <= max_length + + filtered = dataset.filter(is_valid_length, desc="Filtering by length") + print(f"Filtered dataset: {len(filtered)} examples (max_length={max_length})") + return filtered + + def train_stage(self, stage_name: str, stage_config: dict): + """Train a single stage""" + print(f"\n{'='*50}") + print(f"Training stage: {stage_name}") + print(f"Description: {stage_config['description']}") + print(f"{'='*50}\n") + + # Add adapter + self.model_wrapper.add_progressive_adapter(stage_name, stage_config) + + # Load and format dataset + dataset = self.load_dataset(stage_config["dataset_path"], stage_config) + dataset = self.format_dataset(dataset) + + # Filter by sequence length if specified + if "max_length" in stage_config["training"]: + dataset = self.filter_by_length(dataset, stage_config["training"]["max_length"]) + + print(f"Final dataset size: {len(dataset)} examples") + + # Training arguments - with CPU offload optimizations + training_args = TrainingArguments( + output_dir=f"./outputs/checkpoints/{stage_name}", + num_train_epochs=stage_config["training"]["num_epochs"], + per_device_train_batch_size=stage_config["training"]["per_device_batch_size"], + gradient_accumulation_steps=stage_config["training"]["gradient_accumulation_steps"], + learning_rate=float(stage_config["training"]["learning_rate"]), # Ensure it's a float + warmup_steps=stage_config["training"]["warmup_steps"], + logging_steps=stage_config["training"].get("logging_steps", 10), + save_strategy="epoch", + eval_strategy="no", + bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(), + gradient_checkpointing=self.config["model"].get("gradient_checkpointing", False), + max_grad_norm=stage_config["training"].get("max_grad_norm", 1.0), + report_to="wandb" if self.config["experiment"]["use_wandb"] else "none", + run_name=f"{self.config['experiment']['name']}_{stage_name}", + dataloader_pin_memory=False, # Reduce memory usage + remove_unused_columns=False, # Keep all columns + optim=stage_config["training"].get("optim", "adamw_torch"), # Support 8-bit optimizers + 
dataloader_num_workers=stage_config["training"].get("dataloader_num_workers", 2), + ) + + # Print dataset info for debugging + print(f"Dataset columns: {dataset.column_names}") + print(f"Dataset first example: {dataset[0]}") + + # Ensure model is in training mode before creating trainer + self.model_wrapper.model.train() + + # Final check of trainable parameters + trainable_params = sum(p.numel() for p in self.model_wrapper.model.parameters() if p.requires_grad) + print(f"Final check - Trainable parameters: {trainable_params:,}") + + # Create trainer with minimal configuration + try: + trainer = SFTTrainer( + model=self.model_wrapper.model, + processing_class=self.model_wrapper.tokenizer, + train_dataset=dataset, + args=training_args, + packing=False, # Disable packing for better gradient flow + ) + except Exception as e: + print(f"Error creating SFTTrainer: {e}") + print("Trying with basic configuration...") + trainer = SFTTrainer( + model=self.model_wrapper.model, + processing_class=self.model_wrapper.tokenizer, + train_dataset=dataset, + args=training_args, + ) + + # Train + trainer.train() + + # Save adapter + self.model_wrapper.save_adapter(stage_name) + + # Record history + self.training_history.append({ + "stage": stage_name, + "config": stage_config, + "metrics": trainer.state.log_history + }) + + print(f"\nCompleted training stage: {stage_name}") + + def run_progressive_training(self): + """Run all training stages progressively""" + stages = self.config["progressive_stages"] + + for stage_config in stages: + stage_name = stage_config["name"] + self.train_stage(stage_name, stage_config) + + # Save training history + history_path = Path(self.config["experiment"]["output_dir"]) / "training_history.json" + with open(history_path, "w") as f: + json.dump(self.training_history, f, indent=2) + + print(f"\nAll stages completed! Training history saved to: {history_path}") \ No newline at end of file diff --git a/test_data_load.py b/test_data_load.py new file mode 100644 index 0000000..e16e530 --- /dev/null +++ b/test_data_load.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +"""Test data loading""" + +import sys +from pathlib import Path +sys.path.append(str(Path(__file__).parent)) + +from src.training import ProgressiveTrainer +from src.progressive_model import ProgressiveReasoningModel +import yaml + +# Load config +with open("config/training_config.yaml") as f: + config = yaml.safe_load(f) + +# Create dummy model wrapper +class DummyModelWrapper: + def __init__(self): + self.tokenizer = None + +model_wrapper = DummyModelWrapper() + +# Create trainer +trainer = ProgressiveTrainer(model_wrapper, config) + +# Test data loading +stage_config = config["progressive_stages"][0] +dataset_path = stage_config["dataset_path"] +print(f"Loading dataset from: {dataset_path}") + +dataset = trainer.load_dataset(dataset_path) +print(f"Loaded {len(dataset)} examples") + +if len(dataset) > 0: + print(f"First example: {dataset[0]}") \ No newline at end of file
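For reference, the adapter sizing that scripts/analyze_adapter_size.py prints reduces to simple arithmetic: each LoRA target module adds an A matrix of shape (r, in_features) and a B matrix of shape (out_features, r), which is 2 * hidden_size * r parameters when the projection is square. The short sketch below reproduces that calculation together with the bytes-per-parameter memory estimates used in scripts/check_vram.py; the hidden size, rank and module count here are illustrative assumptions, not values taken from the shipped configs.

# Back-of-envelope LoRA sizing, mirroring the arithmetic reported by
# scripts/analyze_adapter_size.py. Hidden size, rank and module count are
# illustrative assumptions, not values read from the repo configs.

def lora_params_per_module(in_features: int, out_features: int, r: int) -> int:
    # LoRA adds A (r x in_features) and B (out_features x r) per target module.
    return r * in_features + out_features * r


def estimate_adapter(hidden_size: int = 768, r: int = 16, num_modules: int = 4) -> None:
    per_module = lora_params_per_module(hidden_size, hidden_size, r)  # square projections
    total = per_module * num_modules
    print(f"per module: {per_module:,} params (2 * {hidden_size} * {r})")
    print(f"total for {num_modules} modules: {total:,} params")
    # Rough memory cost of the adapter weights alone, using the same
    # bytes-per-parameter figures as check_vram.py (fp16 = 2 B, fp32 = 4 B).
    print(f"~{total * 2 / 1024**2:.2f} MB in fp16, ~{total * 4 / 1024**2:.2f} MB in fp32")


if __name__ == "__main__":
    estimate_adapter()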
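The basic CoT data written by src/data_utils.py and consumed by src/training.py is plain JSONL with input/output fields, which the trainer later wraps into chat messages. A minimal sketch of one record follows; the <think> tag name is an assumption (the literal tag text appears to have been lost in this dump), and the output path is illustrative.

# Minimal sketch of the JSONL record shape produced by
# src.data_utils.prepare_basic_cot_data and consumed by
# src.training.ProgressiveTrainer.load_dataset. The <think> tag name and the
# output path are assumptions; the dump above has the tag text stripped.
import json
from pathlib import Path

record = {
    "input": "What is 156 + 389?",
    "output": "<think>\nTo find 156 + 389, I add the two numbers.\n156 + 389 = 545\n</think>\n\nThe answer is 545.",
}

# load_dataset() later reshapes each record into chat messages:
messages = [
    {"role": "user", "content": record["input"]},
    {"role": "assistant", "content": record["output"]},
]

out_dir = Path("data/basic_cot")  # illustrative path, matching prepare_basic_cot_data
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / "train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")

print(messages)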