Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dataset_creation/real_builders/raise_image_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import OrderedDict
from functools import lru_cache
import json
import os
from pathlib import Path
from typing import Iterable, List, Optional
import warnings
Expand All @@ -26,6 +27,8 @@ def __init__(
self.root_path = Path(root_path)
self.convert_to_jpeg = convert_to_jpeg
self.tmp_cache_dir = Path(tmp_cache_dir) if tmp_cache_dir is not None else None
if not self.root_path.exists():
raise FileNotFoundError(f"Root path does not exist: {self.root_path}")

def get_prefix(self) -> str:
return "RAISE"
Expand Down
2 changes: 1 addition & 1 deletion requirements_train_and_evaluation.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ ruamel.yaml
img2dataset
click
torch==2.7.0
torchvision=0.22.0
torchvision==0.22.0
xformers
tensorboard
lightning[pytorch-extra]
Expand Down
15 changes: 9 additions & 6 deletions training_and_evaluation/algorithms/models/openai_clip_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,14 @@ def forward_features(self, x: Tensor) -> Tensor:
def forward_head(self, x: Tensor) -> Tensor:
return self.fc(x)


ModelFactoryRegistry().register_model_factory(
"openai_clip_image", make_openai_clip_image_model
)

import clip
for arch in clip.available_models():
ModelFactoryRegistry().register_model_factory(
f"{arch}_tune", make_openai_clip_image_model
)
ModelFactoryRegistry().register_model_factory(
f"{arch}_probe", make_openai_clip_image_model
)

__all__ = [
"make_openai_clip_image_model",
Expand All @@ -121,5 +124,5 @@ def forward_head(self, x: Tensor) -> Tensor:
if __name__ == "__main__":
model = make_openai_clip_image_model("RN50_tune", num_classes=1)
print(model)
model = make_openai_clip_image_model("RN50_probe", num_classes=1)
model = make_openai_clip_image_model("ViT-L/14_probe", num_classes=1)
print(model)
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,9 @@ def setup_val(self):
real_indices = [i for i, gen in enumerate(generator_labels) if gen == ""]
fake_indices = [i for i, gen in enumerate(generator_labels) if gen != ""]

selected_real_images: List[int] = []
selected_fake_images: List[int] = []

# Balance images (stratified)
if n_real > n_fake:
n_to_select = n_fake
Expand All @@ -382,6 +385,9 @@ def setup_val(self):
stratify=[generator_labels[i] for i in fake_indices],
random_state=self.data_management_seed + 1,
)
else:
selected_real_images = real_indices
selected_fake_images = fake_indices

subset_indices = selected_real_images + selected_fake_images
validation_dataset = validation_dataset.select(subset_indices)
Expand Down Expand Up @@ -495,6 +501,9 @@ def setup_test(self):
real_indices = [i for i, gen in enumerate(generator_labels) if gen == ""]
fake_indices = [i for i, gen in enumerate(generator_labels) if gen != ""]

selected_real_images: List[int] = []
selected_fake_images: List[int] = []

# Balance images (stratified)
if n_real > n_fake:
n_to_select = n_fake
Expand All @@ -512,6 +521,9 @@ def setup_test(self):
stratify=[generator_labels[i] for i in fake_indices],
random_state=self.data_management_seed + 1,
)
else:
selected_real_images = real_indices
selected_fake_images = fake_indices

subset_indices = selected_real_images + selected_fake_images
test_dataset = test_dataset.select(subset_indices)
Expand Down