aibedo.interface.get_model
- aibedo.interface.get_model(config: omegaconf.dictconfig.DictConfig, **kwargs) → aibedo.models.base_model.BaseModel
Get the AIBEDO model, a subclass of BaseModel, as defined by the key-value pairs in config.model.
- Parameters
config (DictConfig) – An OmegaConf config (e.g., produced by Hydra parsing a <config>.yaml file)
**kwargs – Any additional keyword arguments for the model class (these override any matching key in config, if present)
- Returns
BaseModel – A model that can be trained directly with PyTorch Lightning
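The override rule for **kwargs can be pictured with a small, dependency-free sketch. It is not the aibedo implementation: BaseModel and MLP below are hypothetical stand-ins, MODEL_REGISTRY, get_model_sketch and the "name" key that selects the class are assumptions, and a plain dict takes the place of the DictConfig found under config.model.

```python
from typing import Any, Dict, Type


class BaseModel:
    """Hypothetical stand-in for aibedo.models.base_model.BaseModel."""

    def __init__(self, **hparams: Any) -> None:
        # Store whatever hyperparameters the model was built with.
        self.hparams = dict(hparams)


class MLP(BaseModel):
    """Hypothetical MLP subclass; the real aibedo class differs."""


# Hypothetical registry mapping a model name in the config to its class.
MODEL_REGISTRY: Dict[str, Type[BaseModel]] = {"mlp": MLP}


def get_model_sketch(model_config: Dict[str, Any], **kwargs: Any) -> BaseModel:
    """Build a model from a plain-dict stand-in for config.model.

    Keyword arguments override keys of the same name in model_config.
    """
    params = dict(model_config)   # copy so the caller's config is left untouched
    name = params.pop("name")     # assumed key that selects the model class
    params.update(kwargs)         # kwargs take precedence over config values
    model_cls = MODEL_REGISTRY[name]
    return model_cls(**params)


cfg = {"name": "mlp", "hidden_dims": [100, 100], "dropout": 0.0}
model = get_model_sketch(cfg, dropout=0.1)
print(type(model).__name__, model.hparams)
# MLP {'hidden_dims': [100, 100], 'dropout': 0.1}
```

The explicit keyword argument wins over the value stored in the config, while the original config object is not modified.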
Examples:
    import torch
    from aibedo.interface import get_model
    from aibedo.utilities.config_utils import get_config_from_hydra_compose_overrides

    # Build a config that selects the MLP model
    config_mlp = get_config_from_hydra_compose_overrides(overrides=['model=mlp'])
    mlp_model = get_model(config_mlp)
    # Get a prediction for a (B, S, C) shaped input
    random_mlp_input = torch.randn(1, 100, 5)
    random_prediction = mlp_model.predict(random_mlp_input)