def create_model(
    model_name: str,
    model_cfg: dict = None,
    in_channels: int = 3,
    num_classes: int = 80,
    checkpoint_path: str = "",
    **kwargs,
):
    """Build a registered model by name and optionally load checkpoint weights.

    Args:
        model_name: Name of a registered model; must satisfy ``is_model``.
        model_cfg: Model configuration, forwarded to the entrypoint as ``cfg``.
        in_channels: Number of input channels, forwarded to the entrypoint.
        num_classes: Number of output classes, forwarded to the entrypoint.
        checkpoint_path: Optional path to a ``.ckpt`` file. When non-empty, its
            parameters are loaded into the newly built model via
            ``load_checkpoint`` and ``load_param_into_net``.
        **kwargs: Extra keyword arguments forwarded to the entrypoint. Entries
            whose value is ``None`` are dropped before forwarding.

    Returns:
        The model object returned by the entrypoint registered for
        ``model_name``.

    Raises:
        RuntimeError: If ``model_name`` is not a registered model.
        ValueError: If ``checkpoint_path`` is non-empty but is not an existing
            file ending in ``.ckpt``.
    """
    # Fail fast on an unknown name before doing any other work.
    if not is_model(model_name):
        raise RuntimeError(f"Unknown model {model_name}")

    # Validate the checkpoint path up front, before (possibly expensive) model
    # construction. An explicit raise is used instead of `assert`, which is
    # stripped under `python -O` and would let a bad path reach load_checkpoint.
    if checkpoint_path and not (
        os.path.isfile(checkpoint_path) and checkpoint_path.endswith(".ckpt")
    ):
        raise ValueError(f"[{checkpoint_path}] not a ckpt file.")

    # Drop None-valued extras so they are not passed to the entrypoint at all
    # (presumably so the entrypoint's own defaults apply — confirm with callers).
    extra_args = {key: value for key, value in kwargs.items() if value is not None}

    create_fn = model_entrypoint(model_name)
    model = create_fn(
        cfg=model_cfg,
        num_classes=num_classes,
        in_channels=in_channels,
        **extra_args,
    )

    if checkpoint_path:
        checkpoint_param = load_checkpoint(checkpoint_path)
        load_param_into_net(model, checkpoint_param)
        logger.info(f"Load checkpoint from [{checkpoint_path}] success.")
    return model