diff --git a/nanotabpfn/interface.py b/nanotabpfn/interface.py index fa1c3d7..87b4cec 100644 --- a/nanotabpfn/interface.py +++ b/nanotabpfn/interface.py @@ -5,7 +5,7 @@ import requests import torch import torch.nn.functional as F -from pfns.bar_distribution import FullSupportBarDistribution +from pfns.model.bar_distribution import FullSupportBarDistribution from sklearn.compose import ColumnTransformer from sklearn.impute import SimpleImputer from sklearn.pipeline import Pipeline diff --git a/nanotabpfn/train.py b/nanotabpfn/train.py index 0e67e19..e591c50 100644 --- a/nanotabpfn/train.py +++ b/nanotabpfn/train.py @@ -3,7 +3,7 @@ import time from torch.utils.data import DataLoader from typing import Dict -from pfns.bar_distribution import FullSupportBarDistribution +from pfns.model.bar_distribution import FullSupportBarDistribution import schedulefree import os diff --git a/nanotabpfn/utils.py b/nanotabpfn/utils.py index 3a9cbe5..fe879ae 100644 --- a/nanotabpfn/utils.py +++ b/nanotabpfn/utils.py @@ -3,7 +3,7 @@ import torch import numpy as np -from pfns.bar_distribution import get_bucket_limits +from pfns.model.bar_distribution import get_bucket_borders def set_randomness_seed(seed): random.seed(seed) @@ -32,5 +32,5 @@ def make_global_bucket_edges(filename, n_buckets=100, device=get_default_device( raise ValueError(f"Too few target samples ({ys_concat.size}) to compute {n_buckets} buckets.") ys_tensor = torch.tensor(ys_concat, dtype=torch.float32, device=device) - global_bucket_edges = get_bucket_limits(n_buckets, ys=ys_tensor).to(device) + global_bucket_edges = get_bucket_borders(n_buckets, ys=ys_tensor).to(device) return global_bucket_edges diff --git a/pretrain_regression.py b/pretrain_regression.py index 6548d8f..53602ed 100644 --- a/pretrain_regression.py +++ b/pretrain_regression.py @@ -1,7 +1,7 @@ import argparse import torch -from pfns.bar_distribution import FullSupportBarDistribution +from pfns.model.bar_distribution import FullSupportBarDistribution from sklearn.metrics import r2_score from nanotabpfn.callbacks import ConsoleLoggerCallback diff --git a/pyproject.toml b/pyproject.toml index 7b6fa75..53e10ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "requests>=2", "scikit-learn>=1.5", "schedulefree>=1.4", - "pfns==0.3.0", + "pfns==0.4.2", "openml==0.15.1", ]