1+ from functools import cached_property
12from typing import Any , Optional , Type , Union
23
34from pydantic .fields import FieldInfo
45
56import dspy
6- from dspy .clients .base_lm import BaseLM
77from dspy .primitives .module import Module
88from dspy .signatures .field import OutputField
99from dspy .signatures .signature import Signature , ensure_signature
@@ -27,43 +27,60 @@ def __init__(
2727 **config: The configuration for the module.
2828 """
2929 super ().__init__ ()
30- signature = ensure_signature (signature )
31- self .predict_reasoning = dspy .Predict (signature , ** config )
32- prefix = "Reasoning: Let's think step by step in order to"
33- desc = "${reasoning}"
34- rationale_field_type = rationale_field .annotation if rationale_field else rationale_field_type
35- rationale_field = rationale_field if rationale_field else dspy .OutputField (prefix = prefix , desc = desc )
36- extended_signature = signature .prepend (name = "reasoning" , field = rationale_field , type_ = rationale_field_type )
37- self .predict = dspy .Predict (extended_signature , ** config )
30+ self ._signature = ensure_signature (signature )
31+ self ._config = config
32+ self ._rationale_field = rationale_field
33+ self ._rationale_field_type = rationale_field_type
3834
39- def _validate_lm (self , ** kwargs ):
40- """Helper method to validate the LM configuration."""
41- lm = kwargs .get ("lm" , dspy .settings .lm )
42-
43- if lm is None :
44- raise ValueError (
45- "No LM is loaded. Please configure the LM using `dspy.configure(lm=dspy.LM(...))`. e.g, "
46- "`dspy.configure(lm=dspy.LM('openai/gpt-4o-mini'))`"
35+ @cached_property
36+ def predict (self ):
37+ """Returns the appropriate predict instance based on the LM's reasoning model capability."""
38+ lm = dspy .settings .lm
39+ if lm and getattr (lm , "reasoning_model" , False ):
40+ return dspy .Predict (self ._signature , ** self ._config )
41+ else :
42+ prefix = "Reasoning: Let's think step by step in order to"
43+ desc = "${reasoning}"
44+ rationale_field_type = (
45+ self ._rationale_field .annotation if self ._rationale_field else self ._rationale_field_type
4746 )
48-
49- if isinstance (lm , str ):
50- # Many users mistakenly use `dspy.configure(lm="openai/gpt-4o-mini")` instead of
51- # `dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))`, so we are providing a specific error message.
52- raise ValueError (
53- f"LM must be an instance of `dspy.BaseLM`, not a string. Instead of using a string like "
54- f"'dspy.configure(lm=\" { lm } \" )', please configure the LM like 'dspy.configure(lm=dspy.LM(\" { lm } \" ))'"
47+ rationale_field = (
48+ self ._rationale_field if self ._rationale_field else dspy .OutputField (prefix = prefix , desc = desc )
5549 )
56- elif not isinstance ( lm , BaseLM ):
57- raise ValueError ( f"LM must be an instance of `dspy.BaseLM`, not { type ( lm ) } . Received `lm= { lm } `." )
58-
59- return lm
50+ extended_signature = self . _signature . prepend (
51+ name = "reasoning" , field = rationale_field , type_ = rationale_field_type
52+ )
53+ return dspy . Predict ( extended_signature , ** self . _config )
6054
6155 def forward (self , ** kwargs ):
62- lm = self ._validate_lm (** kwargs )
63- return self .predict (** kwargs ) if not lm .reasoning_model else self .predict_reasoning (** kwargs )
56+ return self .predict (** kwargs )
6457
6558 async def aforward (self , ** kwargs ):
66- lm = self ._validate_lm (** kwargs )
67- return await (
68- self .predict .acall (** kwargs ) if not lm .reasoning_model else self .predict_reasoning .acall (** kwargs )
69- )
59+ return await self .predict .acall (** kwargs )
60+
61+ def load_state (self , state ):
62+ """Override to ensure predict parameter is created before loading state."""
63+ # If predict state exists but predict hasn't been accessed yet, access it first
64+ if "predict" in state and "predict" not in self .__dict__ :
65+ _ = self .predict # This creates the predict instance
66+
67+ # Now call the base load_state which will load into all named_parameters
68+ return super ().load_state (state )
69+
70+ def __setstate__ (self , state ):
71+ """Custom deserialization for cloudpickle to preserve predict instance."""
72+ # Restore the state normally
73+ self .__dict__ .update (state )
74+
75+ # If predict was cached and serialized, we don't need to do anything special
76+ # since cloudpickle should have preserved it correctly
77+
78+ def __getstate__ (self ):
79+ """Custom serialization for cloudpickle to ensure predict instance is preserved."""
80+ state = self .__dict__ .copy ()
81+ # Force evaluation of cached property if not already done
82+ if "predict" not in state :
83+ # Access the predict property to cache it before serialization
84+ _ = self .predict
85+ state = self .__dict__ .copy ()
86+ return state
0 commit comments