Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,27 @@ checkpoints/
.idea/
wandb/
*.pth
debug*

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
2 changes: 1 addition & 1 deletion datasets_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from datasets import load_dataset
from debug import DEBUG

class INatDataset(ImageFolder):
Expand Down Expand Up @@ -65,6 +64,7 @@ def build_dataset(is_train, data_path, args):
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
elif args.data_set == 'HUGGINGFACE':
from datasets import load_dataset
def huggingface_transform(examples):
examples["image"] = [transform(x.convert(mode="RGB")) for x in examples["image"]]
return examples
Expand Down
1 change: 1 addition & 0 deletions debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DEBUG=False
53 changes: 40 additions & 13 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,15 @@ def get_args_parser():
help='use scale-aware embeds')
parser.add_argument('--grid-to-random-ratio', default=0.7, type=float, help='hybrid sampler grid to random ratio')

parser.add_argument('--model-type', default='deit', type=str, choices=["deit", "swin", "pvt"])


return parser


def main(args):
if args.model_type != "deit":
assert args.eval, "For non-deit models, only evaluation is implemented."
utils.init_distributed_mode(args)

print(args)
Expand Down Expand Up @@ -356,19 +361,39 @@ def log_to_wandb(log_dict, step):
mixup_fn = None

print(f"Creating model: {args.model}")
model = create_model(
args.model,
pretrained=False,
num_classes=args.nb_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
drop_block_rate=None,
img_size=args.input_size,

Patch_layer=PatchEmbedHybrid,
use_learned_pos_embed = args.use_learned_pos_embed,
quantize_pos_embed = args.quantize_pos_embed
)
if args.model_type == "deit":
model = create_model(
args.model,
pretrained=False,
num_classes=args.nb_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
drop_block_rate=None,
img_size=args.input_size,

Patch_layer=PatchEmbedHybrid,
use_learned_pos_embed = args.use_learned_pos_embed,
quantize_pos_embed = args.quantize_pos_embed
)
elif args.model_type == "swin":
from swin.config import get_config
from swin.models import build_model
class _Args:
pass
_args = _Args()
cfg_path = "swin/swin_base_patch4_window7_224.yaml"
setattr(_args, "cfg", cfg_path)
setattr(_args, "opts", [])
setattr(_args, "local_rank", 0)
setattr(_args, "data_path", args.data_path)
config = get_config(_args)
model = build_model(config)
elif args.model_type == "pvt":
assert args.model == 'pvt_v2_b5'
from pvt.pvt_v2 import pvt_v2_b5
model = pvt_v2_b5()
else:
assert False

if args.finetune:
if args.finetune.startswith('https'):
Expand Down Expand Up @@ -490,6 +515,8 @@ def log_to_wandb(log_dict, step):
else:
checkpoint = torch.load(args.resume, map_location='cpu')
#del checkpoint['model']['pos_embed']
if args.model_type == "pvt":
checkpoint = {"model": checkpoint}
print(model_without_ddp.load_state_dict(checkpoint['model'], strict=False))
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
Expand Down
Loading