|
566 | 566 | }
|
567 | 567 | ],
|
568 | 568 | "source": [
|
569 |
| - "class Rectangle():\n", |
570 |
| - " def __init__(self, w, h, center):\n", |
571 |
| - " self.w = w\n", |
572 |
| - " self.h = h\n", |
| 569 | + "class Circle():\n", |
| 570 | + " def __init__(self, center):\n", |
573 | 571 | " self.center = center\n",
|
574 | 572 | "\n",
|
575 | 573 | " def __call__(self, size):\n",
|
576 | 574 | " n = torch.randn([size, 2])\n",
|
577 | 575 | " n = n / n.norm(dim=-1).unsqueeze(1)\n",
|
578 | 576 | " return n + self.center + 0.1 * torch.randn([size, 2])\n",
|
579 |
| - "# x = self.w * torch.rand([size, 1])\n", |
580 |
| - "# y = self.h * torch.rand([size, 1])\n", |
581 |
| - "# return torch.cat([x, y], dim=1) + self.center\n", |
582 | 577 | "\n",
|
583 | 578 | "\n",
|
584 |
| - "distribution = Rectangle(1, 4, torch.tensor([4, 4]))\n", |
| 579 | + "distribution = Circle(torch.tensor([4, 4]))\n", |
585 | 580 | "ys = distribution(1024)\n",
|
586 | 581 | "\n",
|
587 | 582 | "nf = NormalizingFlow(\n",
|
|
0 commit comments