diff --git a/app.py b/app.py index e9bb1fd..2ae4cfb 100644 --- a/app.py +++ b/app.py @@ -285,6 +285,9 @@ def generate_examples(id_image, control_image, prompt_text, seed, enable_realism """ ) +if torch.backends.mps.is_available(): + torch.set_default_device("mps:0") + download_models() prepare_pipeline(model_version=ModelVersion.DEFAULT_VERSION, enable_realism=ENABLE_REALISM_DEFAULT, enable_anti_blur=ENABLE_ANTI_BLUR_DEFAULT) diff --git a/pipelines/pipeline_infu_flux.py b/pipelines/pipeline_infu_flux.py index bbc42e3..25d666d 100644 --- a/pipelines/pipeline_infu_flux.py +++ b/pipelines/pipeline_infu_flux.py @@ -21,7 +21,8 @@ import numpy as np import torch from diffusers.models import FluxControlNetModel -from facexlib.recognition import init_recognition_model +from facexlib.recognition import init_recognition_model, Backbone +from facexlib.utils import load_file_from_url from huggingface_hub import snapshot_download from insightface.app import FaceAnalysis from insightface.utils import face_align @@ -30,6 +31,13 @@ from .pipeline_flux_infusenet import FluxInfuseNetPipeline from .resampler import Resampler +def get_device(): + """Get the appropriate device for the current system.""" + if torch.cuda.is_available(): + return "cuda" + elif torch.backends.mps.is_available(): + return "mps" + return "cpu" def seed_everything(seed, deterministic=False): """Set random seed. @@ -44,8 +52,12 @@ def seed_everything(seed, deterministic=False): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) + device = get_device() + if device == "cuda": + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + elif device == "mps": + torch.mps.manual_seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) if deterministic: torch.backends.cudnn.deterministic = True @@ -82,14 +94,31 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2 return out_img_pil +def init_arcface_model(device='cpu', model_rootpath=None): + """Initialize ArcFace model with proper device handling.""" + model = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se') + model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth' + + model_path = load_file_from_url( + url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) + # Load state dict with proper device mapping + state_dict = torch.load(model_path, map_location=torch.device(device)) + model.load_state_dict(state_dict, strict=True) + model.eval() + model = model.to(device) + return model + + def extract_arcface_bgr_embedding(in_image, landmark, arcface_model=None, in_settings=None): kps = landmark arc_face_image = face_align.norm_crop(in_image, landmark=np.array(kps), image_size=112) arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0,3,1,2) / 255. arc_face_image = 2 * arc_face_image - 1 - arc_face_image = arc_face_image.cuda().contiguous() + arc_face_image = arc_face_image.contiguous() if arcface_model is None: - arcface_model = init_recognition_model('arcface', device='cuda') + arcface_model = init_arcface_model(device=get_device()) + # Move input to the same device as the model + arc_face_image = arc_face_image.to(next(arcface_model.parameters()).device) face_emb = arcface_model(arc_face_image)[0] # [512], normalized return face_emb @@ -131,33 +160,33 @@ def __init__( infu_flux_version='v1.0', model_version='aes_stage2', ): - + self.device = get_device() self.infu_flux_version = infu_flux_version self.model_version = model_version # Load pipeline try: infusenet_path = os.path.join(infu_model_path, 'InfuseNetModel') - self.infusenet = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.bfloat16) + self.infusenet = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.float32) except: print("No InfiniteYou model found. Downloading from HuggingFace `ByteDance/InfiniteYou` to `./models/InfiniteYou` ...") snapshot_download(repo_id='ByteDance/InfiniteYou', local_dir='./models/InfiniteYou', local_dir_use_symlinks=False) infu_model_path = os.path.join('./models/InfiniteYou', f'infu_flux_{infu_flux_version}', model_version) infusenet_path = os.path.join(infu_model_path, 'InfuseNetModel') - self.infusenet = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.bfloat16) + self.infusenet = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.float32) insightface_root_path = './models/InfiniteYou/supports/insightface' try: pipe = FluxInfuseNetPipeline.from_pretrained( base_model_path, controlnet=self.infusenet, - torch_dtype=torch.bfloat16, + torch_dtype=torch.float32, ) except: try: pipe = FluxInfuseNetPipeline.from_single_file( base_model_path, controlnet=self.infusenet, - torch_dtype=torch.bfloat16, + torch_dtype=torch.float32, ) except Exception as e: print(e) @@ -166,8 +195,8 @@ def __init__( 'Then, use `huggingface-cli login` and your access tokens at https://huggingface.co/settings/tokens to authenticate. ' 'After that, run the code again. If you have downloaded it, please use `base_model_path` to specify the correct path.') print('\nIf you are using other models, please download them to a local directory and use `base_model_path` to specify the correct path.') - exit() - pipe.to('cuda', torch.bfloat16) + raise e + pipe.to(self.device) self.pipe = pipe # Load image proj model @@ -184,28 +213,29 @@ def __init__( ff_mult=4, ) image_proj_model_path = os.path.join(infu_model_path, 'image_proj_model.bin') - ipm_state_dict = torch.load(image_proj_model_path, map_location="cpu") + ipm_state_dict = torch.load(image_proj_model_path, map_location=torch.device(self.device)) image_proj_model.load_state_dict(ipm_state_dict['image_proj']) del ipm_state_dict - image_proj_model.to('cuda', torch.bfloat16) + image_proj_model.to(self.device) image_proj_model.eval() self.image_proj_model = image_proj_model # Load face encoder + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if torch.cuda.is_available() else ['CPUExecutionProvider'] self.app_640 = FaceAnalysis(name='antelopev2', - root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + root=insightface_root_path, providers=providers) self.app_640.prepare(ctx_id=0, det_size=(640, 640)) self.app_320 = FaceAnalysis(name='antelopev2', - root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + root=insightface_root_path, providers=providers) self.app_320.prepare(ctx_id=0, det_size=(320, 320)) self.app_160 = FaceAnalysis(name='antelopev2', - root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + root=insightface_root_path, providers=providers) self.app_160.prepare(ctx_id=0, det_size=(160, 160)) - self.arcface_model = init_recognition_model('arcface', device='cuda') + self.arcface_model = init_arcface_model(device=self.device) def load_loras(self, loras): names, scales = [],[] @@ -255,15 +285,15 @@ def __call__( face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face landmark = face_info['kps'] id_embed = extract_arcface_bgr_embedding(id_image_cv2, landmark, self.arcface_model) - id_embed = id_embed.clone().unsqueeze(0).float().cuda() + id_embed = id_embed.clone().unsqueeze(0).float() id_embed = id_embed.reshape([1, -1, 512]) - id_embed = id_embed.to(device='cuda', dtype=torch.bfloat16) + id_embed = id_embed.to(device=self.device) with torch.no_grad(): id_embed = self.image_proj_model(id_embed) bs_embed, seq_len, _ = id_embed.shape id_embed = id_embed.repeat(1, 1, 1) id_embed = id_embed.view(bs_embed * 1, seq_len, -1) - id_embed = id_embed.to(device='cuda', dtype=torch.bfloat16) + id_embed = id_embed.to(device=self.device) # Load control image print('Preparing the control image') diff --git a/test.py b/test.py index 1a76c9e..9f9308f 100644 --- a/test.py +++ b/test.py @@ -48,7 +48,11 @@ def main(): assert args.model_version in ['aes_stage2', 'sim_stage1'], 'Currently only supports model versions: aes_stage2 | sim_stage1' # Set cuda device - torch.cuda.set_device(args.cuda_device) + if torch.cuda.is_available(): + torch.cuda.set_device(args.cuda_device) + elif torch.backends.mps.is_available(): + torch.set_default_device("mps:0") + print(f'Using cuda device: {torch.empty(1).device}') # Load pipeline infu_model_path = os.path.join(args.model_dir, f'infu_flux_{args.infu_flux_version}', args.model_version)