aibedo.interface.get_model

aibedo.interface.get_model(config: omegaconf.dictconfig.DictConfig, **kwargs) → aibedo.models.base_model.BaseModel

Get the AIBEDO model, a subclass of BaseModel, as defined by the key-value pairs in config.model.

Parameters
  • config (DictConfig) – An OmegaConf config (e.g. produced by parsing a Hydra <config>.yaml file)

  • **kwargs – Any additional keyword arguments for the model class (overrides any key in config, if present)

Returns

BaseModel – The model, which can be used directly for training with PyTorch Lightning

Examples:

import torch

from aibedo.interface import get_model
from aibedo.utilities.config_utils import get_config_from_hydra_compose_overrides

# Compose a config that selects the MLP model, then build the model from it
config_mlp = get_config_from_hydra_compose_overrides(overrides=['model=mlp'])
mlp_model = get_model(config_mlp)

# Get a prediction for a (B, S, C) shaped input
random_mlp_input = torch.randn(1, 100, 5)
random_prediction = mlp_model.predict(random_mlp_input)
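
The override rule for **kwargs (a keyword argument wins over the same key in config) can be pictured with plain dictionaries. This is a minimal sketch of the precedence only, not the AIBEDO implementation: the helper merge_model_kwargs and the key names hidden_dims, dropout, and activation are hypothetical and chosen purely for illustration.

```python
def merge_model_kwargs(model_config: dict, **kwargs) -> dict:
    """Hypothetical helper: return model_config with kwargs taking precedence."""
    merged = dict(model_config)  # shallow copy, so the caller's config is untouched
    merged.update(kwargs)        # keyword arguments win on key collisions
    return merged


# Hypothetical key names, used only to illustrate the precedence
model_config = {"hidden_dims": [128, 128], "dropout": 0.1}
resolved = merge_model_kwargs(model_config, dropout=0.0, activation="gelu")
print(resolved)
```

Under this reading, a call such as get_model(config, dropout=0.0) would build the model with the overridden value while leaving every other key in config.model as it was, assuming the model class accepts that argument.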