Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ def generate_examples(id_image, control_image, prompt_text, seed, enable_realism
"""
)

if torch.backends.mps.is_available():
torch.set_default_device("mps:0")

download_models()

prepare_pipeline(model_version=ModelVersion.DEFAULT_VERSION, enable_realism=ENABLE_REALISM_DEFAULT, enable_anti_blur=ENABLE_ANTI_BLUR_DEFAULT)
Expand Down
72 changes: 51 additions & 21 deletions pipelines/pipeline_infu_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import numpy as np
import torch
from diffusers.models import FluxControlNetModel
from facexlib.recognition import init_recognition_model
from facexlib.recognition import init_recognition_model, Backbone
from facexlib.utils import load_file_from_url
from huggingface_hub import snapshot_download
from insightface.app import FaceAnalysis
from insightface.utils import face_align
Expand All @@ -30,6 +31,13 @@
from .pipeline_flux_infusenet import FluxInfuseNetPipeline
from .resampler import Resampler

def get_device():
"""Get the appropriate device for the current system."""
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
return "cpu"

def seed_everything(seed, deterministic=False):
"""Set random seed.
Expand All @@ -44,8 +52,12 @@ def seed_everything(seed, deterministic=False):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
device = get_device()
if device == "cuda":
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
elif device == "mps":
torch.mps.manual_seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
Expand Down Expand Up @@ -82,14 +94,31 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
return out_img_pil


def init_arcface_model(device='cpu', model_rootpath=None):
"""Initialize ArcFace model with proper device handling."""
model = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se')
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth'

model_path = load_file_from_url(
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
# Load state dict with proper device mapping
state_dict = torch.load(model_path, map_location=torch.device(device))
model.load_state_dict(state_dict, strict=True)
model.eval()
model = model.to(device)
return model


def extract_arcface_bgr_embedding(in_image, landmark, arcface_model=None, in_settings=None):
kps = landmark
arc_face_image = face_align.norm_crop(in_image, landmark=np.array(kps), image_size=112)
arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0,3,1,2) / 255.
arc_face_image = 2 * arc_face_image - 1
arc_face_image = arc_face_image.cuda().contiguous()
arc_face_image = arc_face_image.contiguous()
if arcface_model is None:
arcface_model = init_recognition_model('arcface', device='cuda')
arcface_model = init_arcface_model(device=get_device())
# Move input to the same device as the model
arc_face_image = arc_face_image.to(next(arcface_model.parameters()).device)
face_emb = arcface_model(arc_face_image)[0] # [512], normalized
return face_emb

Expand Down Expand Up @@ -131,33 +160,33 @@ def __init__(
infu_flux_version='v1.0',
model_version='aes_stage2',
):

self.device = get_device()
self.infu_flux_version = infu_flux_version
self.model_version = model_version

# Load pipeline
try:
infusenet_path = os.path.join(infu_model_path, 'InfuseNetModel')
self.infusenet = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.bfloat16)
self.infusenet = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.float32)
except:
print("No InfiniteYou model found. Downloading from HuggingFace `ByteDance/InfiniteYou` to `./models/InfiniteYou` ...")
snapshot_download(repo_id='ByteDance/InfiniteYou', local_dir='./models/InfiniteYou', local_dir_use_symlinks=False)
infu_model_path = os.path.join('./models/InfiniteYou', f'infu_flux_{infu_flux_version}', model_version)
infusenet_path = os.path.join(infu_model_path, 'InfuseNetModel')
self.infusenet = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.bfloat16)
self.infusenet = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.float32)
insightface_root_path = './models/InfiniteYou/supports/insightface'
try:
pipe = FluxInfuseNetPipeline.from_pretrained(
base_model_path,
controlnet=self.infusenet,
torch_dtype=torch.bfloat16,
torch_dtype=torch.float32,
)
except:
try:
pipe = FluxInfuseNetPipeline.from_single_file(
base_model_path,
controlnet=self.infusenet,
torch_dtype=torch.bfloat16,
torch_dtype=torch.float32,
)
except Exception as e:
print(e)
Expand All @@ -166,8 +195,8 @@ def __init__(
'Then, use `huggingface-cli login` and your access tokens at https://huggingface.co/settings/tokens to authenticate. '
'After that, run the code again. If you have downloaded it, please use `base_model_path` to specify the correct path.')
print('\nIf you are using other models, please download them to a local directory and use `base_model_path` to specify the correct path.')
exit()
pipe.to('cuda', torch.bfloat16)
raise e
pipe.to(self.device)
self.pipe = pipe

# Load image proj model
Expand All @@ -184,28 +213,29 @@ def __init__(
ff_mult=4,
)
image_proj_model_path = os.path.join(infu_model_path, 'image_proj_model.bin')
ipm_state_dict = torch.load(image_proj_model_path, map_location="cpu")
ipm_state_dict = torch.load(image_proj_model_path, map_location=torch.device(self.device))
image_proj_model.load_state_dict(ipm_state_dict['image_proj'])
del ipm_state_dict
image_proj_model.to('cuda', torch.bfloat16)
image_proj_model.to(self.device)
image_proj_model.eval()

self.image_proj_model = image_proj_model

# Load face encoder
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if torch.cuda.is_available() else ['CPUExecutionProvider']
self.app_640 = FaceAnalysis(name='antelopev2',
root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
root=insightface_root_path, providers=providers)
self.app_640.prepare(ctx_id=0, det_size=(640, 640))

self.app_320 = FaceAnalysis(name='antelopev2',
root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
root=insightface_root_path, providers=providers)
self.app_320.prepare(ctx_id=0, det_size=(320, 320))

self.app_160 = FaceAnalysis(name='antelopev2',
root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
root=insightface_root_path, providers=providers)
self.app_160.prepare(ctx_id=0, det_size=(160, 160))

self.arcface_model = init_recognition_model('arcface', device='cuda')
self.arcface_model = init_arcface_model(device=self.device)

def load_loras(self, loras):
names, scales = [],[]
Expand Down Expand Up @@ -255,15 +285,15 @@ def __call__(
face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
landmark = face_info['kps']
id_embed = extract_arcface_bgr_embedding(id_image_cv2, landmark, self.arcface_model)
id_embed = id_embed.clone().unsqueeze(0).float().cuda()
id_embed = id_embed.clone().unsqueeze(0).float()
id_embed = id_embed.reshape([1, -1, 512])
id_embed = id_embed.to(device='cuda', dtype=torch.bfloat16)
id_embed = id_embed.to(device=self.device)
with torch.no_grad():
id_embed = self.image_proj_model(id_embed)
bs_embed, seq_len, _ = id_embed.shape
id_embed = id_embed.repeat(1, 1, 1)
id_embed = id_embed.view(bs_embed * 1, seq_len, -1)
id_embed = id_embed.to(device='cuda', dtype=torch.bfloat16)
id_embed = id_embed.to(device=self.device)

# Load control image
print('Preparing the control image')
Expand Down
6 changes: 5 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ def main():
assert args.model_version in ['aes_stage2', 'sim_stage1'], 'Currently only supports model versions: aes_stage2 | sim_stage1'

# Set cuda device
torch.cuda.set_device(args.cuda_device)
if torch.cuda.is_available():
torch.cuda.set_device(args.cuda_device)
elif torch.backends.mps.is_available():
torch.set_default_device("mps:0")
print(f'Using cuda device: {torch.empty(1).device}')

# Load pipeline
infu_model_path = os.path.join(args.model_dir, f'infu_flux_{args.infu_flux_version}', args.model_version)
Expand Down