# Source code for pylmrob.control

# SPDX-License-Identifier: GPL-3.0-or-later
"""Control parameters for ``lmrob``.

Mirrors robustbase's ``lmrob.control``. Phase 8 fills in the named-preset
defaults for ``setting in {"KS2011", "KS2014"}``. The values here come
from ``robustbase/R/lmrob.R``.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal

# Psi-function family names; used as the type of ``Control.psi``.
PsiFamily = Literal["bisquare", "huber", "hampel", "optimal", "ggw", "lqq", "welsh"]
# Initial-estimator names; used as the type of ``Control.init``.
InitMethod = Literal["auto", "S", "M-S", "L1"]
# Named presets; used as the type of ``Control.setting``.
Setting = Literal["KS2011", "KS2014", "MM"]


# Default psi tuning constants, keyed by psi family, matching R's
# robustbase 0.99-7. ``Control.__post_init__`` copies the entry for the
# resolved psi family into ``tuning_psi`` when that field is left as None.
#
# Values are stored in **R's internal form**, i.e. the result of applying
# ``robustbase:::.psi.conv.cc`` to the user-facing defaults from
# ``robustbase::.Mpsi.tuning.default()`` (see that function and
# ``.Mchi.tuning.default`` for the canonical table).
_DEFAULT_TUNING_PSI: dict[str, tuple[float, ...]] = {
    "huber": (1.345,),
    "bisquare": (4.685061,),
    "hampel": (1.5 * 0.9016085, 3.5 * 0.9016085, 8.0 * 0.9016085),
    "optimal": (1.060158,),
    "welsh": (2.11,),
    # ggw: case index 4 = (b=1.5, 95% efficiency).
    "ggw": (4,),
    # lqq: (b, c, s) - converted from user-facing (-0.5, 1.5, 0.95, NA).
    "lqq": (1.4734061, 0.9822707, 1.5),
}

# Default chi tuning constants, keyed by psi family, in the same R-internal
# form as ``_DEFAULT_TUNING_PSI``. ``Control.__post_init__`` copies the entry
# for the resolved psi family into ``tuning_chi`` when that field is None.
_DEFAULT_TUNING_CHI: dict[str, tuple[float, ...]] = {
    "huber": (0.6745,),
    "bisquare": (1.547645,),
    "hampel": (1.5 * 0.2119163, 3.5 * 0.2119163, 8.0 * 0.2119163),
    "optimal": (0.4047,),
    "welsh": (0.5773502691896258,),  # = 1/sqrt(3); gives E[chi(Z)]=0.5 under Z~N(0,1).
    # ggw: case index 6 = (b=1.5, breakdown=0.5).
    "ggw": (6,),
    # lqq: (b, c, s) - converted from user-facing (-0.5, 1.5, NA, 0.5).
    "lqq": (0.4015457, 0.2676971, 1.5),
}


@dataclass
class Control:
    """Parameters controlling an ``lmrob`` fit.

    Mirrors R's ``lmrob.control``. Fields left as ``None`` are resolved in
    ``__post_init__`` from ``setting``:

    - ``setting=None`` or ``"MM"`` follows R's plain default:
      psi="bisquare", method="MM", cov=".vcov.avar1".
    - ``setting="KS2011"`` / ``"KS2014"``: psi="lqq", method="SMDM"
      (with D-step), cov=".vcov.w".

    Pass ``psi`` / ``method`` / ``cov`` explicitly to override. When
    ``tuning_psi`` / ``tuning_chi`` are ``None`` they are filled from the
    module's default tables for the resolved psi family.
    """

    setting: Setting | None = None
    psi: PsiFamily | None = None
    tuning_chi: float | tuple[float, ...] | None = None
    tuning_psi: float | tuple[float, ...] | None = None
    init: InitMethod = "S"
    method: str | None = None
    nResample: int = 500
    max_it: int = 50
    k_max: int = 200
    refine_tol: float = 1e-7
    rel_tol: float = 1e-7
    solve_tol: float = 1e-7
    scale_tol: float = 1e-10
    zero_tol: float = 1e-10
    best_r_s: int = 2
    k_fast_s: int = 1
    k_m_s: int = 20
    mts: int = 1000
    subsampling: Literal["nonsingular", "simple"] = "nonsingular"
    cov: str | None = None
    eps_outlier: float | None = None
    eps_x: float | None = None
    seed: int | None = None
    trace_lev: int = 0

    # Worker threads for the fast-S resampling loop.
    # 1 = serial (default; deterministic). 0 = auto (only enables threading
    # for problems large enough that BLAS dominates Python/GIL overhead).
    # Any positive integer = explicit worker count.
    n_workers: int = 1

    # When True, draw resampling subsets inside the Cython kernel using
    # numpy's BitGenerator C API (Floyd's combination algorithm). About
    # 1.4-2x faster at n<=500 because the rank-check SVD goes away.
    # Off by default: the draw sequence is not byte-identical with
    # ``np.random.Generator.choice`` so the basin of attraction can shift
    # slightly, which can push some fits beyond rtol=1e-3 agreement with R
    # on some small-n datasets. Opt in when you want raw speed and can
    # tolerate the drift; not recommended for reproducibility-sensitive
    # workloads.
    fast_rng: bool = False

    # BitGenerator backing the resample RNG.
    #
    # - ``"PCG64"`` (default): the modern NumPy default. Fast and
    #   statistically better than MT19937, but produces a different
    #   subset-draw sequence than R.
    # - ``"MT19937"``: NumPy's Mersenne Twister. Closer to R's family
    #   but the seed-to-state path and the 64-bit-vs-32-bit raw output
    #   still differ from R's ``RNG.c``.
    # - ``"R"``: drive the resample loop from ``pylmrob.r_set_seed`` and
    #   the R-compatible samplers, which is byte-identical to robustbase's
    #   draw sequence. Requires ``n_workers=1`` (``__post_init__`` raises
    #   ``ValueError`` otherwise) and forces ``engine_c=False``. The
    #   requested ``subsampling`` mode is honoured as given.
    rng: Literal["PCG64", "MT19937", "R"] = "PCG64"

    # Use the monolithic Cython engine (pylmrob._core._lmrob). When True,
    # fast-S + MM + vcov_avar1 run in one nogil C block with one workspace
    # allocation. On small n this is 5-10x the default Python path; on
    # large n ``lmrob()`` auto-falls-back to the threaded default path
    # (since the monolithic kernel is a single C call that does not
    # parallelise). The Cython subset-draw is not byte-identical to
    # ``np.random.choice``: on a few small classical datasets it lands
    # in a basin where ``vcov_avar1`` is singular; ``lmrob()`` catches
    # that FloatingPointError and retries with ``engine_c=False`` so
    # the fit always succeeds.
    engine_c: bool = True

    bb: float = 0.5  # consistency constant (target value of mean(chi))
    extra: dict[str, object] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Resolve ``None`` fields from ``setting`` and validate ``rng``.

        Every value filled in here is also recorded in the private
        ``_derived_defaults`` instance attribute (deliberately not a
        dataclass field) so that ``set_params`` can tell auto-derived values
        apart from user-supplied ones and re-derive them later.

        Raises:
            ValueError: if ``rng == "R"`` with ``n_workers != 1``, or if a
                tuning default is needed for a psi family that has no entry
                in the default tables.
        """
        derived: dict[str, object] = {}

        # Apply named-setting defaults, mirroring R's
        # lmrob.control(setting=...), for fields the user passed as None.
        if self.setting in ("KS2014", "KS2011"):
            # Both KS2011 and KS2014 use the SMDM pipeline in robustbase.
            # They differ in tuning constants and ``cov.corrfact`` defaults.
            setting_defaults = {"psi": "lqq", "method": "SMDM", "cov": ".vcov.w"}
        else:
            setting_defaults = {
                "psi": "bisquare",
                "method": "MM",
                "cov": ".vcov.avar1",
            }
        for name, value in setting_defaults.items():
            if getattr(self, name) is None:
                setattr(self, name, value)
                derived[name] = value

        # rng="R" implies a serial, R-call-order resample path. Enforce the
        # matching control values now so the rest of the pipeline doesn't
        # have to special-case combinations that don't make sense. Both
        # "simple" and "nonsingular" subsampling are supported on this path,
        # so ``subsampling`` is left as the user requested.
        if self.rng == "R":
            if self.n_workers != 1:
                raise ValueError(
                    "rng='R' requires n_workers=1 (R's unif_rand stream is"
                    f" sequential); got n_workers={self.n_workers}"
                )
            self.engine_c = False

        # Fill in default tuning constants matching R. The tables are only
        # consulted when the corresponding field is actually unset.
        if self.tuning_psi is None:
            if self.psi not in _DEFAULT_TUNING_PSI:
                raise ValueError(f"unknown psi family {self.psi!r}")
            self.tuning_psi = _DEFAULT_TUNING_PSI[self.psi]
            derived["tuning_psi"] = self.tuning_psi
        if self.tuning_chi is None:
            if self.psi not in _DEFAULT_TUNING_CHI:
                raise ValueError(f"unknown psi family {self.psi!r}")
            self.tuning_chi = _DEFAULT_TUNING_CHI[self.psi]
            derived["tuning_chi"] = self.tuning_chi

        # Always bind a fresh dict so shallow copies of this instance never
        # share mutable bookkeeping.
        self._derived_defaults = derived

    # sklearn-compatibility shim. Lets ``GridSearchCV`` reach Control fields
    # via the standard ``estimator__nested__field`` parameter syntax, e.g.
    # ``param_grid={"control__nResample": [200, 500, 1000]}``.
    def get_params(self, deep: bool = True) -> dict[str, object]:
        """Return ``Control``'s public fields as a sklearn-style dict.

        ``deep`` is accepted for signature compatibility with sklearn and is
        not used: the mapping is always ``{field name: current value}``.
        """
        import dataclasses

        return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}

    def set_params(self, **params: object) -> Control:
        """Set ``Control``'s public fields in place, sklearn convention.

        Values that ``__post_init__`` derived automatically (and that have
        not been reassigned since) are re-derived after the update, so e.g.
        changing ``psi`` also refreshes auto-filled tuning constants and
        changing ``setting`` refreshes auto-filled ``psi``/``method``/``cov``.
        Values supplied explicitly by the caller are never overwritten.

        The update is all-or-nothing: on any error the instance is restored
        to its previous state before the exception propagates.

        Raises:
            ValueError: if a key is not a ``Control`` field, or if the
                resulting combination is rejected by ``__post_init__``.

        Returns:
            ``self``, to allow chaining.
        """
        import dataclasses

        valid = {f.name for f in dataclasses.fields(self)}
        # Validate every key up front so a typo cannot leave a partial update.
        for key in params:
            if key not in valid:
                raise ValueError(f"Invalid parameter {key!r} for Control")

        snapshot = dict(vars(self))
        try:
            # Clear still-untouched auto-derived values so __post_init__
            # recomputes them. Identity (not equality) is used on purpose: a
            # value the user reassigned by hand is treated as explicit.
            for name, auto_value in snapshot.get("_derived_defaults", {}).items():
                if name not in params and getattr(self, name) is auto_value:
                    setattr(self, name, None)
            for key, value in params.items():
                setattr(self, key, value)
            # Control is a regular @dataclass, so __post_init__ can simply be
            # invoked again to re-resolve and re-validate.
            self.__post_init__()
        except Exception:
            vars(self).update(snapshot)
            raise
        return self

    @classmethod
    def preset(cls, setting: Setting, **overrides: object) -> Control:
        """Build a Control for a named preset.

        Settings:

        - ``"KS2014"``: psi="bisquare", cov=".vcov.avar1", init="S".
        - ``"KS2011"``: psi="bisquare", cov=".vcov.w", init="S".
        - ``"MM"``: legacy MM defaults (psi="bisquare", cov=".vcov.avar1",
          init="S").

        ``method`` is not pinned here; it is resolved from ``setting`` by
        ``__post_init__`` unless given in ``overrides``.

        ``overrides`` are merged over the preset values and passed through
        the constructor, so derived defaults (tuning constants) and the
        ``rng="R"`` checks reflect the final field values.

        Raises:
            ValueError: if ``setting`` is not a known preset, or the final
                combination is rejected by ``__post_init__``.
            TypeError: if an override key is not a ``Control`` field.
        """
        import dataclasses

        # NOTE(review): these presets pin psi="bisquare" for the KS settings,
        # whereas constructing ``Control(setting="KS2014")`` directly resolves
        # psi to "lqq". Kept as-is for backward compatibility -- confirm
        # which behaviour is intended.
        if setting == "KS2014":
            chosen = {"psi": "bisquare", "cov": ".vcov.avar1", "init": "S"}
        elif setting == "KS2011":
            chosen = {"psi": "bisquare", "cov": ".vcov.w", "init": "S"}
        elif setting == "MM":
            chosen = {"psi": "bisquare", "cov": ".vcov.avar1", "init": "S"}
        else:
            raise ValueError(f"unknown setting: {setting!r}")

        valid = {f.name for f in dataclasses.fields(cls)}
        for key in overrides:
            if key not in valid:
                raise TypeError(f"unknown Control field: {key!r}")

        kwargs: dict[str, object] = {"setting": setting, **chosen, **overrides}
        return cls(**kwargs)