progressive-llm/flake.nix
Tatsuhiko Akiyama 37f1ad9408 ok
2025-07-10 21:00:42 +09:00

201 lines
6.9 KiB
Nix

{
description = "Progressive LLM Training for LLM2025";
# Flake inputs: the package collection that is imported below, and the
# helper library that provides the per-system output wrapper.
inputs.nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
inputs.flake-utils.url = "github:numtide/flake-utils";
outputs = { self, nixpkgs, flake-utils }:
flake-utils.lib.eachDefaultSystem (system:
let
# Package set for the current system, with unfree packages allowed and CUDA
# support enabled globally. A single overlay rebuilds the Python 3.11 package
# set with the checks of a few packages turned off.
#
# NOTE(review): the enclosing eachDefaultSystem call also instantiates this
# configuration for non-Linux systems, where the CUDA packages are presumably
# unavailable; confirm whether the system list should be restricted.
pkgs = import nixpkgs {
  inherit system;
  config = {
    allowUnfree = true;
    cudaSupport = true;
  };
  overlays = [
    (final: prev: {
      python311 = prev.python311.override {
        packageOverrides = python-self: python-super:
          let
            inherit (prev.lib) genAttrs;

            # overridePythonAttrs changes the arguments handed to
            # buildPythonPackage, which is the documented way to switch off
            # a Python package's test suite (overrideAttrs only patches the
            # derivation that the builder has already produced).
            withoutTests = name:
              python-super.${name}.overridePythonAttrs (_: {
                doCheck = false;
              });

            # Same as withoutTests, and additionally clears the list of
            # modules that the post-install import check would try to load.
            withoutTestsOrImportChecks = name:
              python-super.${name}.overridePythonAttrs (_: {
                doCheck = false;
                pythonImportsCheck = [ ];
              });
          in
          # Packages whose test suites are problematic; import checks stay on.
          genAttrs [
            "pytest-doctestplus"
            "jupyter"
            "notebook"
          ] withoutTests
          # Packages for which both the tests and the import checks are skipped.
          // genAttrs [
            "psycopg"
            "psycopg2"
            "sqlframe"
            "accelerate"
            "curl-cffi"
          ] withoutTestsOrImportChecks;
      };
    })
  ];
};
# Interpreter used throughout: Python 3.11 (chosen for better compatibility),
# taken from the overlaid package set.
python = pkgs.python311;
# The interpreter bundled with every Python library the project uses. The
# groups below are concatenated in the order they are listed.
pythonWithPackages = python.withPackages (ps:
  let
    # Core ML packages
    coreMl = [
      ps.torch
      ps.torchvision
      ps.torchaudio
      ps.transformers
      ps.accelerate
      ps.datasets
      ps.tokenizers
      ps.scikit-learn
    ];
    # Required dependencies from requirements.txt
    projectRequirements = [
      ps.pyyaml
      ps.jsonlines
      ps.sentencepiece
      ps.protobuf
    ];
    # Additional useful packages
    extras = [
      ps.numpy
      ps.scipy
      ps.matplotlib
      ps.jupyter
      ps.notebook
      ps.ipython
      ps.pandas
      ps.rich # For TUI
    ];
    # Development tools
    devTools = [
      ps.black
      ps.flake8
      ps.pytest
      ps.mypy
    ];
    # Build tools
    buildTools = [
      ps.pip
      ps.setuptools
      ps.wheel
    ];
    # LLM specific packages
    llmSpecific = [
      ps.peft
      ps.trl
      ps.bitsandbytes
      ps.wandb
    ];
  in
  coreMl
  ++ projectRequirements
  ++ extras
  ++ devTools
  ++ buildTools
  ++ llmSpecific);
in
{
# Default development shell: the Python environment, native build tools,
# a few command-line utilities and the CUDA toolkit/cuDNN. Entering the shell
# prints a short status banner, extends the library and Python search paths,
# creates ./data and tries to generate the sample datasets if they are missing.
devShells.default =
  let
    cudaRoot = pkgs.cudaPackages.cudatoolkit;
    cudnnRoot = pkgs.cudaPackages.cudnn;
  in
  pkgs.mkShell {
    buildInputs = with pkgs; [
      # Python with packages
      pythonWithPackages
      # Build tools
      gcc
      cmake
      ninja
      pkg-config
      # Git
      git
      git-lfs
      # Development tools
      htop
      tmux
      vim
      # Libraries needed for Python packages
      openssl
      zlib
      glib
      stdenv.cc.cc.lib
      # CUDA support
      cudaPackages.cudatoolkit
      cudaPackages.cudnn
    ];
    shellHook = ''
      echo "🚀 Progressive LLM Training Environment"
      echo "Python version: $(python --version)"
      # A single interpreter start (and a single torch import) reports both facts.
      python -c 'import torch; print("PyTorch version:", torch.__version__); print("CUDA available:", torch.cuda.is_available())'
      # Expose the CUDA, cuDNN and C++ runtime libraries to the dynamic loader.
      # The previous value is appended only when it is non-empty, so no empty
      # entry (which the loader treats as the current directory) is created.
      export LD_LIBRARY_PATH=${cudaRoot}/lib:${cudnnRoot}/lib:${pkgs.stdenv.cc.cc.lib}/lib''${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}
      # Put the project sources on the Python path, again without leaving an
      # empty trailing entry when PYTHONPATH was previously unset.
      export PYTHONPATH=$PWD/src''${PYTHONPATH:+:$PYTHONPATH}
      echo ""
      echo "Available commands:"
      echo " python scripts/train_progressive.py # Start training"
      echo " python scripts/evaluate.py # Evaluate model"
      echo " jupyter notebook # Start Jupyter"
      echo ""
      # Create data directory if not exists
      mkdir -p data
      # Prepare sample data if not exists; a failure here is reported but
      # does not prevent the shell from starting.
      if [ ! -f "data/basic_cot/train.jsonl" ]; then
        echo "Preparing sample datasets..."
        python -c "from src.data_utils import prepare_sample_datasets; prepare_sample_datasets()" || echo "Sample data preparation skipped"
      fi
      # Note about flash-attn
      echo "Note: flash-attn is not included in nixpkgs. If needed, install manually with:"
      echo " pip install flash-attn --no-build-isolation"
    '';
    # Environment variables. CUDA_HOME and CUDA_PATH are defined only here
    # rather than being exported a second time from shellHook.
    CUDA_HOME = "${cudaRoot}";
    CUDA_PATH = "${cudaRoot}";
    NIX_SHELL_PRESERVE_PROMPT = 1;
    LOCALE_ARCHIVE = "${pkgs.glibcLocales}/lib/locale/locale-archive";
    LC_ALL = "en_US.UTF-8";
  };
});
}