# General imports
import os
import random
import math
import numpy as np
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
from loader import load_blender_data # https://github.com/yenchenlin/nerf-pytorch/blob/master/load_blender.py
# Camera imports
from ctypes.wintypes import HACCEL
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import Dataset, DataLoader
# Testing imports
from matplotlib import cm, colors
from mpl_toolkits.mplot3d import Axes3D
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# Train imports
from torch import nn
import ruamel.yaml as yaml
import argparse
import pickle
# Scene imports
from spherical_harmonics import get_spherical_harmonics # https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/spherical_harmonics.py
c:\Users\Victor\anaconda3\envs\minimal\lib\site-packages\tqdm\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
def set_seed(seed):
    """
    Seed every RNG used in this file and make cuDNN deterministic.

    Seeds Python's `random`, NumPy's global RNG and torch (CPU and all CUDA devices),
    disables cuDNN autotuning and forces deterministic cuDNN kernels. PYTHONHASHSEED is
    exported as well; note that it only affects interpreters started afterwards, not the
    current process's str hashing.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
set_seed(0)
class Camera(Dataset):
    """
    Dataset of camera rays: one element per (pose, pixel) pair.

    Element i is (ray direction, ray origin, RGBA value, [pose, row, col]), i.e. the colour
    observed along the ray through one pixel of one image. Directions and origins are
    homogeneous 4-vectors (w = 0 for directions, w = 1 for origins).
    """

    def __init__(self, H, W, focal, poses, imgs):
        """
        Store intrinsics, poses and images, and precompute one ray per pixel.

        H, W : image height and width in pixels
        focal : focal length in pixels
        poses : FloatTensor(N, 4, 4) camera-to-world matrices
        imgs : FloatTensor(N, H, W, 4) RGBA images, indexed [pose, row, col]
        """
        self.H = H
        self.W = W
        self.f = focal
        # Not used by this class at the moment
        self.near = 0.1
        self.far = 100.0
        self.poses = poses
        self.iposes = torch.inverse(poses)  # world-to-camera matrices
        self.imgs = imgs
        self.rays, self.centers = self.get_rays((H, W))

    def get_rays(self, resolution):
        """
        Compute one ray per cell of a resolution[0] x resolution[1] grid for every pose.

        The grid spans the whole H x W image (rows 0..H-1, columns 0..W-1), so
        resolution=(H, W) gives exactly one ray per pixel. Uses nerf-pytorch's pinhole
        convention: the camera looks down -z and image rows grow along -y.

        Returns (rays_d, rays_o), each FloatTensor(N, resolution[0], resolution[1], 4):
        unit homogeneous directions (w = 0) and ray origins (w = 1), indexed [pose, row, col].
        """
        H, W, f = self.H, self.W, self.f
        device = self.poses.device
        rows = torch.linspace(0, H - 1, resolution[0], device=device)
        cols = torch.linspace(0, W - 1, resolution[1], device=device)
        # 'ij' indexing keeps the grid row-major, matching the [pose, row, col] image layout
        j, i = torch.meshgrid(rows, cols, indexing="ij")
        dirs = torch.stack([(i - W / 2) / f, -(j - H / 2) / f, -torch.ones_like(i)], -1)
        # R @ d for every pose and grid cell: camera-frame directions -> world frame
        rays_d = torch.sum(dirs[None, :, :, None, :] * self.poses[:, None, None, :3, :3], -1)
        rays_o = self.poses[:, None, None, :3, -1].expand_as(rays_d)
        rays_d = torch.nn.functional.normalize(rays_d, dim=-1)
        # Homogeneous coordinate: 0 for directions, 1 for points (shape follows the grid)
        rays_d = torch.cat([rays_d, torch.zeros_like(rays_d[..., :1])], dim=-1)
        rays_o = torch.cat([rays_o, torch.ones_like(rays_o[..., :1])], dim=-1)
        return rays_d, rays_o

    def get_rays_from_points(self, points):
        """
        Unit rays from every camera centre towards every point.

        points : FloatTensor(P, 3) world-space points
        Returns (rays, centers): rays FloatTensor(N, P, 4) homogeneous unit directions (w = 0),
        centers FloatTensor(N, P, 4) homogeneous camera centres (w = 1).
        """
        points = torch.cat([points, torch.ones_like(points[:, :1])], dim=-1)
        origin = torch.zeros_like(points)
        origin[:, 3] = 1.0
        centers = (self.poses[:, None, :, :] @ origin[None, :, :, None]).squeeze(-1)
        rays = torch.nn.functional.normalize(points[None, :, :] - centers, dim=-1)
        return rays, centers

    def get_pixels(self, points):
        """
        Project world-space points into every image (the inverse of get_rays).

        points : FloatTensor(P, 3)
        Returns FloatTensor(N, P, 2) of continuous (row, col) pixel coordinates in the same
        [pose, row, col] indexing as self.rays / self.imgs: a point on the ray
        self.rays[p, r, c] projects to (r, c) in view p. There is no visibility test (points
        behind a camera are projected too), and points level with a camera centre (camera-frame
        z == 0) give inf/nan.
        """
        rays, _ = self.get_rays_from_points(points)
        # Rotate the world-space directions into each camera frame (w = 0 drops translation)
        cam = (self.iposes[:, None, :, :] @ rays[..., None]).squeeze(-1)
        x_over_z = cam[..., 0] / cam[..., 2]
        y_over_z = cam[..., 1] / cam[..., 2]
        # get_rays uses d = ((col - W/2)/f, -(row - H/2)/f, -1); solve for row and col
        rows = self.H / 2 + self.f * y_over_z
        cols = self.W / 2 - self.f * x_over_z
        return torch.stack([rows, cols], dim=-1)

    def __len__(self):
        """Number of (pose, pixel) pairs."""
        return self.poses.size(0) * self.W * self.H

    def __getitem__(self, idx):
        """
        Return (ray, center, RGBA, LongTensor([pose, row, col])) for flat index idx.

        Shapes: (4,), (4,), (4,), (3,). Flat indices enumerate poses, then rows, then columns.
        """
        pose, pixel = divmod(idx, self.H * self.W)
        row, col = divmod(pixel, self.W)
        return (
            self.rays[pose, row, col, :],
            self.centers[pose, row, col, :],
            self.imgs[pose, row, col, :],
            torch.LongTensor([pose, row, col]),
        )
class TriangleSoup():
    """
    Radiance field whose density is non-zero only on the triangles of a triangle soup.

    Every vertex carries a density and complex spherical-harmonic (SH) colour coefficients;
    both are interpolated barycentrically across each triangle. Rendering composites the
    closest `tk` ray/triangle hits with transmittance weights and is differentiable with
    respect to vertex positions, densities and SH coefficients.
    """

    def __init__(self, num_points=1000, nn=10, k=5, camera=None, tk=100):
        """
        Build a random soup inside the box [-1.2, 1.2]^3.

        num_points : number of vertices
        nn : number of triangles formed between each vertex and its nearest neighbours
        k : number of SH degrees used for vertex colours (degrees 0 .. k-1, k*k coefficients)
        camera : optional Camera; if given, camera.get_pixels(points) is cached in self.pixels
                 (only rasterize() uses it)
        tk : maximum number of closest triangle hits composited per ray
        """
        self.num_points = num_points
        self.num_new_points = num_points // 2
        self.points = torch.rand((self.num_points, 3)) * 2.4 - 1.2
        self.density = 0.1 * torch.rand((self.num_points,))
        self.k = k
        # iv[d] = number of SH coefficients of degrees < d; degree d lives in iv[d]:iv[d+1]
        self.iv = [i * i for i in range(k + 1)]
        self.sh_coef = torch.rand((self.num_points, 3, self.iv[k]), dtype=torch.cfloat)
        self.tk = tk
        self.nn = nn
        self._build_triangles()
        self.computed_topk = False
        self.pixels = None if camera is None else camera.get_pixels(self.points)

    def _build_triangles(self):
        """Refit the kNN index on the current vertices and rebuild self.triangles (nn triangles per vertex)."""
        pts = self.points.detach().cpu().numpy()
        self.nbrs = NearestNeighbors(n_neighbors=self.nn + 1, algorithm='ball_tree').fit(pts)
        _, indices = self.nbrs.kneighbors(pts)
        # Column 0 of kneighbors is normally the query vertex itself, so it is dropped
        neighbours = torch.as_tensor(indices, dtype=torch.long)[:, 1:]
        centre = torch.arange(neighbours.size(0)).repeat_interleave(self.nn)
        # Triangle (vertex, k-th neighbour, (k-1)-th neighbour): a fan around each vertex
        self.triangles = torch.stack(
            [centre, neighbours.flatten(), torch.roll(neighbours, 1, dims=-1).flatten()],
            dim=-1,
        ).to(self.points.device)

    def initialize_points(self, camera):
        """
        Rejection-sample vertices so that each one lies inside the object's silhouette.

        A vertex is redrawn uniformly from [-1.2, 1.2]^3 until it projects inside the image in
        at least one view and, in every view where it lands inside the image, onto a pixel with
        alpha > 0.01. NOTE: loops forever if no such region exists. Rebuilds the triangles.
        """
        mask = torch.ones((self.num_points,), dtype=torch.bool)
        while True:
            # (N_poses, M, 2) integer (row, col) projections of the vertices being resampled
            pixels = torch.floor(camera.get_pixels(self.points[mask, :])).long()
            rows, cols = pixels[:, :, 0], pixels[:, :, 1]
            outside = (rows < 0) | (rows >= camera.H) | (cols < 0) | (cols >= camera.W)
            pose_idx = torch.arange(pixels.size(0)).unsqueeze(-1).expand_as(rows)
            rows = rows.clamp(0, camera.H - 1)
            cols = cols.clamp(0, camera.W - 1)
            opaque = camera.imgs[pose_idx, rows, cols, 3] > 0.01
            inside = ~outside
            seen_opaque = (opaque & inside).sum(dim=0) >= 1
            seen_background = (~opaque & inside).sum(dim=0) > 0
            mask[mask.clone()] = ~seen_opaque | seen_background
            n_bad = int(mask.sum())
            if n_bad == 0:
                break
            self.points[mask] = torch.rand((n_bad, 3)) * 2.4 - 1.2
        self._build_triangles()

    def to(self, device):
        """Move all model tensors to device (in place) and return self."""
        self.points = self.points.to(device)
        self.density = self.density.to(device)
        self.sh_coef = self.sh_coef.to(device)
        self.triangles = self.triangles.to(device)
        if self.pixels is not None:
            self.pixels = self.pixels.to(device)
        return self

    def save(self, filename):
        """
        Save vertex positions, densities, SH coefficients and triangles to an .npz file.

        TODO: add a matching load method.
        """
        np.savez(
            filename,
            points=self.points.detach().cpu().numpy(),
            density=self.density.detach().cpu().numpy(),
            sh_coef=self.sh_coef.detach().cpu().numpy(),
            triangles=self.triangles.detach().cpu().numpy(),
        )

    def resample(self):
        """
        Redraw the vertex set so that more vertices sit near dense regions.

        num_new_points vertices are drawn uniformly from [-1.2, 1.2]^3 (the initialisation
        box); the other num_points - num_new_points are Gaussian jitters (std 0.1) of existing
        vertices picked with probability proportional to log(1 + density). Each new vertex gets
        a Gaussian-weighted average of the density / SH coefficients of its nn + 1 nearest old
        vertices. The new tensors stay on the model's device but are fresh tensors without
        requires_grad, so they must be re-registered with any optimizer.
        """
        device = self.points.device
        old_points = self.points.detach().cpu()
        old_density = self.density.detach().cpu()
        old_sh = self.sh_coef.detach().cpu()

        uniform_points = torch.rand((self.num_new_points, 3)) * 2.4 - 1.2
        prb = np.log(np.clip(old_density.numpy(), 0.0001, None) + 1).astype(np.float64)
        prb = prb / prb.sum()
        picked = np.random.choice(np.arange(self.num_points), size=self.num_points - self.num_new_points, p=prb)
        picked = torch.as_tensor(picked, dtype=torch.long)
        jittered_points = old_points[picked] + 0.1 * torch.randn_like(old_points[picked])
        new_points = torch.cat([uniform_points, jittered_points], dim=0)

        # self.nbrs is still fitted on the old vertices here
        distances, indices = self.nbrs.kneighbors(new_points.numpy())
        indices = torch.as_tensor(indices, dtype=torch.long)
        weights = torch.exp(-40.0 * torch.as_tensor(distances, dtype=torch.float32) ** 2)
        # Guard against every weight underflowing to 0 (would give 0/0 = nan)
        norm = weights.sum(dim=1).clamp_min(torch.finfo(weights.dtype).tiny)
        density = (weights * old_density[indices]).sum(dim=1) / norm
        sh_coef = (weights[:, :, None, None] * old_sh[indices]).sum(dim=1) / norm[:, None, None]

        self.points = new_points.to(device)
        self.density = density.to(device)
        self.sh_coef = sh_coef.to(device)
        self._build_triangles()

    def get_topk(self, rays, centers, triangles):
        """
        Intersect every ray with every triangle (Moller-Trumbore) and keep the closest hits.

        rays : FloatTensor(BS, 3)
        centers : FloatTensor(BS, 3)
        triangles : LongTensor(m, 3)
        Returns (idx, t, u, v), each of shape (BS, min(tk, m)): indices into `triangles` of the
        closest hits sorted by distance, the ray parameter t (2000.0 for misses) and the
        barycentric coordinates u, v, already gathered/sorted (no need to index them with idx).
        With no triangles, returns four empty (BS, 0) tensors.
        """
        if triangles.size(0) == 0:
            n = rays.size(0)
            return (
                torch.zeros((n, 0), dtype=torch.long, device=rays.device),
                rays.new_zeros((n, 0)),
                rays.new_zeros((n, 0)),
                rays.new_zeros((n, 0)),
            )
        v0 = self.points[triangles[:, 0], :]
        e1 = self.points[triangles[:, 1], :] - v0
        e2 = self.points[triangles[:, 2], :] - v0
        s = centers[:, None, :] - v0[None, :, :]
        p = torch.cross(rays[:, None, :], e2[None, :, :], dim=-1)
        q = torch.cross(s, e1[None, :, :], dim=-1)
        det = torch.sum(p * e1[None, :, :], dim=-1)
        # Avoid division by ~0 for rays parallel to a triangle
        det[torch.abs(det) < 1e-9] = 1e-9
        t = torch.sum(q * e2[None, :, :], dim=-1) / det
        u = torch.sum(p * s, dim=-1) / det
        v = torch.sum(q * rays[:, None, :], dim=-1) / det
        hit = (t > 0.01) & (t < 1000.0) & (u >= 0) & (v >= 0) & (u + v <= 1)
        # Misses are pushed past every real hit so topk(largest=False) ranks them last
        t[~hit] = 2000.0
        tk = min(self.tk, triangles.size(0))
        idx = torch.topk(t, tk, largest=False, dim=-1).indices
        return idx, torch.gather(t, 1, idx), torch.gather(u, 1, idx), torch.gather(v, 1, idx)

    def local_render(self, rays, centers, triangles, topk, t, u, v):
        """
        Composite the known closest hits of each ray into an RGBA colour.

        rays : FloatTensor(BS, 3) unit viewing directions
        centers : FloatTensor(BS, 3) (unused, kept for interface symmetry with get_topk)
        triangles : LongTensor(m, 3)
        topk, t, u, v : FloatTensor/LongTensor(BS, tk) as returned by get_topk
        Returns FloatTensor(BS, 4): RGB from the SH colours weighted by the per-hit
        transmittance weights, and alpha = sum of those weights. Returns 0 when there are no
        triangles. The input tensors are not modified.
        """
        if triangles.size(0) == 0:
            return 0
        # Viewing direction as (phi = polar angle from +z, theta = |azimuth| in [0, pi]).
        # NOTE(review): theta drops the sign of the y component, and the helper below is called
        # as (theta, phi); confirm this matches get_spherical_harmonics' argument convention.
        phi = torch.arccos(torch.clamp(rays[:, 2], -1.0, 1.0))
        sinphi = torch.sin(phi)
        sinphi[torch.abs(sinphi) < 1e-9] += 1e-8
        theta = torch.arccos(torch.clamp(rays[:, 0] / sinphi, -1.0, 1.0))

        hit = (t > 0.01) & (t < 1000.0) & (u >= 0) & (v >= 0) & (u + v <= 1)
        hit_f = hit.float()
        tri = triangles[topk]  # (BS, tk, 3) vertex ids of each hit triangle
        w0 = 1 - u - v

        sigma = (w0 * self.density[tri[:, :, 0]]
                 + u * self.density[tri[:, :, 1]]
                 + v * self.density[tri[:, :, 2]])
        sigma = torch.clip(sigma * hit_f, 0.0001, 20.0)
        csigma = torch.clip(torch.cumsum(sigma, dim=-1), 0.0001, 20.0)
        transmittance = torch.exp(-csigma)
        # Weight of hit i = T(before i) - T(after i), with T(before first hit) = 1
        weights = torch.zeros_like(transmittance)
        weights[:, 0] = 1 - transmittance[:, 0]
        weights[:, 1:] = transmittance[:, :-1] - transmittance[:, 1:]
        weights = weights * hit_f

        local_sh = (w0[:, :, None, None] * self.sh_coef[tri[:, :, 0]]
                    + u[:, :, None, None] * self.sh_coef[tri[:, :, 1]]
                    + v[:, :, None, None] * self.sh_coef[tri[:, :, 2]])
        local_sh = torch.sum(local_sh * weights[:, :, None, None], dim=-3)  # (BS, 3, iv[k])

        rgba = torch.zeros((rays.size(0), 4), dtype=torch.cfloat, device=rays.device)
        rgba[:, 3] = weights.sum(dim=-1)
        theta_cpu, phi_cpu = theta.cpu(), phi.cpu()
        for degree in range(self.k):
            basis = get_spherical_harmonics(degree, theta_cpu, phi_cpu).to(rays.device)
            coeffs = local_sh[:, :, self.iv[degree]:self.iv[degree + 1]]
            rgba[:, :3] += torch.sum(basis[:, None, :] * coeffs, dim=-1)
        return rgba.real

    def render(self, rays, centers):
        """
        Render a batch of rays: get_topk followed by local_render.

        rays, centers : FloatTensor(BS, >=3); only the first three components are used
        Returns (rgba, topk, t, u, v): rgba FloatTensor(BS, 4) and the get_topk outputs,
        which can be fed back into local_render.
        """
        rays = rays[:, :3]
        centers = centers[:, :3]
        topk, t, u, v = self.get_topk(rays, centers, self.triangles)
        rgba = self.local_render(rays, centers, self.triangles, topk, t, u, v)
        return rgba, topk, t, u, v

    def rasterize(self, rays, centers, pid):
        """
        Experimental column-by-column renderer of one image.

        For each image column only triangles whose projected bounding box (from the cached
        self.pixels, (row, col) per vertex) spans that column are intersected.
        rays, centers : presumably (1, R, C, 4) or (R, C, 4) for a single view - TODO confirm
        pid : index into self.pixels of that view (the code assumes a single view)
        Returns a tensor shaped like rays with the rendered RGBA per pixel.
        """
        pixels = self.pixels[pid]
        corners = pixels[:, self.triangles, :]  # (B, m, 3, 2)
        bbl = corners.min(dim=-2).values
        bbr = corners.max(dim=-2).values
        rgba = torch.zeros_like(rays)
        rays_img = rays.squeeze(0)
        centers_img = centers.squeeze(0)
        # View into rgba with the same layout as rays_img, so column writes land in columns
        out = rgba.squeeze(0)
        for wid in range(rays_img.size(-2)):
            in_column = (wid >= bbl[:, :, 1]) & (wid <= bbr[:, :, 1])
            tris = self.triangles[in_column.squeeze(0)]
            col_rays = rays_img[:, wid, :3]
            col_centers = centers_img[:, wid, :3]
            topk, t, u, v = self.get_topk(col_rays, col_centers, tris)
            out[:, wid, :] = self.local_render(col_rays, col_centers, tris, topk, t, u, v)
        return rgba
# Functions to test/debug intermediate steps
def test_loader():
    """Print the first image's value range, array shapes, first pose, intrinsics and split of the small lego set."""
    imgs, poses, _, (H, W, focal), i_split = load_blender_data("data/lego_small")
    first_image = imgs[0, :, :, :]
    print(first_image.min(), first_image.max())
    print(imgs.shape)
    print(poses[0, :, :])
    print(poses.shape)
    print(H, W, focal)
    print(i_split)
def test_rays():
    """
    Scatter points along two families of rays of every small-lego view to eyeball the geometry.

    Rays of the second row of the ray grid are sampled at t = 0.25 / 0.5 / 1.0 (red / green /
    blue), rays of the last column at the same t (pink / orange / yellow).
    """
    imgs, poses, _, [H, W, focal], i_split = load_blender_data("data/lego_small")
    camera = Camera(H, W, focal, torch.Tensor(poses), torch.Tensor(imgs))
    rays, centers = camera.get_rays((H, W))
    rays = rays.reshape(poses.shape[0], H, W, -1)
    centers = centers.reshape(poses.shape[0], H, W, -1)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    families = [
        ((slice(None), 1, slice(None)), [(0.25, 'r'), (0.5, 'g'), (1.0, 'b')]),
        ((slice(None), slice(None), -1), [(0.25, 'pink'), (0.5, 'orange'), (1.0, 'yellow')]),
    ]
    for selection, samples in families:
        origins = centers[selection][..., :3]
        directions = rays[selection][..., :3]
        for t, c in samples:
            pts = origins + t * directions
            # flatten (not reshape(-1, 3)) so any number of rays/poses works
            ax.scatter(pts[..., 0].flatten(), pts[..., 1].flatten(), pts[..., 2].flatten(), c=c)
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    plt.show()
def test_get_pixels():
    """Project 1000 random points of the unit cube into every small-lego view and print a few."""
    n_preview = 5
    points = torch.rand((1000, 3))
    imgs, poses, _, (H, W, focal), _split = load_blender_data("data/lego_small")
    camera = Camera(H, W, focal, torch.Tensor(poses), torch.Tensor(imgs))
    pixels = camera.get_pixels(points)
    print(points.shape, pixels.shape)
    print(points[:n_preview, :])
    print(pixels[:, :n_preview, :])
    
def test_dataset():
    """Draw one shuffled batch of 256 rays from the small-lego Camera dataset and scatter points along them."""
    imgs, poses, _, (H, W, focal), _split = load_blender_data("data/lego_small")
    camera = Camera(H, W, focal, torch.Tensor(poses), torch.Tensor(imgs))
    loader = DataLoader(camera, batch_size=256, shuffle=True, num_workers=0)
    batch = next(iter(loader), None)
    if batch is None:
        return
    print(len(batch))
    rays, centers, _, _ = batch
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for t, colour in ((0.25, 'r'), (0.5, 'g'), (1.0, 'b')):
        coords = [centers[:, axis] + t * rays[:, axis] for axis in range(3)]
        ax.scatter(*coords, c=colour)
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    plt.show()
def test_triangle_soup_render():
    """
    Render the small-lego views at 1/10 resolution with a random 20-vertex soup.

    Times 10 repeated local_render calls on the cached hits, then shows the third view with
    its RGB rescaled to [0, 1] over the whole image (as main does for validation renders).
    """
    imgs, poses, _, [H, W, focal], i_split = load_blender_data("data/lego_small")
    factor = 10
    H = H // factor
    W = W // factor
    focal = focal / factor
    # Downsample the images too so camera.imgs lines up with the (H, W) ray grid
    imgs = imgs[:, : H * factor : factor, : W * factor : factor, :]
    camera = Camera(H, W, focal, torch.Tensor(poses), torch.Tensor(imgs))
    rays, centers = camera.get_rays((H, W))
    rays = rays.reshape(-1, 4)
    centers = centers.reshape(-1, 4)
    scene = TriangleSoup(20)
    rgba, topk, t, u, v = scene.render(rays, centers)
    for _ in tqdm(range(10)):
        rgba = scene.local_render(rays[:, :3], centers[:, :3], scene.triangles, topk, t, u, v)
    print(rgba.size(), "rgba")
    rgba = rgba.reshape(camera.poses.size(0), H, W, 4)
    view = rgba[2, :, :, :].clone()
    # Per-image (not per-pixel) rescale keeps relative colours; skip the divide on a flat image
    view[..., :3] -= view[..., :3].min()
    span = view[..., :3].max()
    if span > 0:
        view[..., :3] /= span
    plt.matshow(view.numpy())
    plt.show()
# Smoke test: print value range, shapes, first pose, intrinsics and split of the small lego set
test_loader()
0.0 1.0 (3, 800, 800, 4) [[-9.9990219e-01 4.1922452e-03 -1.3345719e-02 -5.3798322e-02] [-1.3988681e-02 -2.9965907e-01 9.5394367e-01 3.8454704e+00] [-4.6566129e-10 9.5403719e-01 2.9968831e-01 1.2080823e+00] [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 1.0000000e+00]] (3, 4, 4) 800 800 1111.1110311937682 [array([0]), array([1]), array([2])]
# Visual check of the ray geometry (opens a 3-D matplotlib plot)
test_rays()
c:\Users\Victor\anaconda3\envs\minimal\lib\site-packages\torch\functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\TensorShape.cpp:2228.) return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
# Project random points into every view and print a few pixel coordinates
test_get_pixels()
torch.Size([1000, 3]) torch.Size([3, 1000, 2])
tensor([[0.4963, 0.7682, 0.0885],
        [0.1320, 0.3074, 0.6341],
        [0.4901, 0.8964, 0.4556],
        [0.6323, 0.3489, 0.4017],
        [0.0223, 0.1689, 0.2939]])
tensor([[[571.1018, 448.1463],
         [442.1175, 239.0049],
         [582.6046, 338.2523],
         [596.6450, 312.4406],
         [406.7428, 332.0547]],
        [[479.9893, 646.1439],
         [414.5891, 467.2859],
         [475.9926, 690.2575],
         [561.4541, 530.6998],
         [392.5014, 432.2550]],
        [[559.3786, 561.8347],
         [443.1466, 331.7718],
         [575.6266, 524.8148],
         [600.0204, 394.2191],
         [406.1947, 376.8306]]])
# Plot one shuffled batch of rays drawn from the Camera dataset
test_dataset()
4
# Render the small lego views with a random 20-vertex soup and time local_render
test_triangle_soup_render()
100%|██████████| 10/10 [00:26<00:00, 2.62s/it]
torch.Size([19200, 4]) rgba
def train(dataloader, scene, optimizer, device, canvas_size):
    """
    Run one training epoch over dataloader and return (canvas, losses).

    Each batch is rendered with scene.render, scored with MSE against the target RGBA and
    used for one optimizer step (gradients of density and SH coefficients clipped to norm 10).
    canvas (shape canvas_size, on device) receives the rendered RGBA at each sample's
    [pose, row, col]; losses holds the per-batch loss values. Prints "oh no" if a render
    contains NaN and the loss every 100 batches.
    """
    criterion = nn.MSELoss()
    canvas = torch.zeros(canvas_size, device=device)
    losses = []
    for step, (rays, centers, imgs, pixels) in enumerate(tqdm(dataloader)):
        rays, centers, imgs, pixels = (x.to(device) for x in (rays, centers, imgs, pixels))
        rgba, _, _, _, _ = scene.render(rays, centers)
        if torch.isnan(rgba).any():
            print("oh no")
        loss = criterion(rgba, imgs)
        canvas[pixels[:, 0], pixels[:, 1], pixels[:, 2], :] = rgba.detach()
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_([scene.density, scene.sh_coef], 10.0)
        optimizer.step()
        value = loss.item()
        losses.append(value)
        if step % 100 == 0:
            print(f"loss: {value:7f}")
    return canvas, losses
def eval(dataloader, scene, device, canvas_size):
    """
    Render every batch of dataloader with scene and measure the reconstruction error.

    Returns (canvas, mse, psnr):
      canvas : tensor of shape canvas_size (on device) with the rendered RGBA written at each
               sample's [pose, row, col]
      mse    : mean squared RGBA error over all samples actually rendered
      psnr   : 10 * log10(1 / mse), inf when mse == 0 (assumes values in [0, 1])
    Both metrics are nan if the dataloader yields nothing. Prints "oh no" if a render
    contains NaN.
    """
    canvas = torch.zeros(canvas_size, device=device)
    sq_err = 0.0
    count = 0
    # Inference only: avoid building autograd graphs on the scene's trainable tensors
    with torch.no_grad():
        for rays, centers, imgs, pixels in tqdm(dataloader):
            rays = rays.to(device)
            centers = centers.to(device)
            imgs = imgs.to(device)
            pixels = pixels.to(device)
            rgba, _, _, _, _ = scene.render(rays, centers)
            if torch.isnan(rgba).any():
                print("oh no")
            sq_err += torch.sum((rgba - imgs) ** 2).item()
            # Count what was scored, so dropped/partial batches do not bias the mean
            count += imgs.numel()
            canvas[pixels[:, 0], pixels[:, 1], pixels[:, 2], :] = rgba
    if count == 0:
        return canvas, float("nan"), float("nan")
    mse = sq_err / count
    psnr = float("inf") if mse == 0 else 10 * np.log10(1 / mse)
    return canvas, mse, psnr
def main(config_file):
    """
    Train and evaluate a TriangleSoup on a Blender dataset described by a YAML config.

    Config keys read: name, device, seed, dataset, downsample_factor,
    model.{num_vertices, knn, sph_degree}, training.{epochs, batch_size, lr},
    evaluation.batch_size.

    Every epoch writes to logs/<name>/epoch_XX/: model.npz, render.npy and train/render_XX.png
    (training renders), losses.pkl, val.pkl ([mse, psnr]) and val/render_XX.png
    (validation renders, RGB rescaled per image).
    """
    with open(config_file, "r") as f:
        # ruamel.yaml's module-level safe_load is deprecated (removed in newer releases)
        cfg = yaml.YAML(typ="safe").load(f)
    print(cfg)
    set_seed(cfg["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() and cfg["device"] == "cuda" else "cpu")
    print(device)
    imgs, poses, _, [H, W, focal], i_split = load_blender_data(os.path.join("data", cfg["dataset"]))
    factor = cfg["downsample_factor"]
    H = H // factor
    W = W // factor
    focal = focal / factor
    # Slice exactly H x W pixels so images match the ray grid even if H, W % factor != 0
    imgs = imgs[:, : H * factor : factor, : W * factor : factor, :]
    log_dir = os.path.join("logs", cfg["name"])
    os.makedirs(log_dir, exist_ok=True)
    poses = torch.Tensor(poses)
    imgs = torch.Tensor(imgs)
    print(imgs.size())
    train_camera = Camera(H, W, focal, poses[i_split[0]], imgs[i_split[0]])
    val_camera = Camera(H, W, focal, poses[i_split[1]], imgs[i_split[1]])
    scene = TriangleSoup(cfg["model"]["num_vertices"], cfg["model"]["knn"], cfg["model"]["sph_degree"])
    scene.to(device)
    for tensor in (scene.points, scene.density, scene.sh_coef):
        tensor.requires_grad = True
    train_dataloader = DataLoader(train_camera, batch_size=cfg["training"]["batch_size"], shuffle=True, num_workers=0, drop_last=True)
    optimizer = torch.optim.Adam([scene.points, scene.density, scene.sh_coef], lr=cfg["training"]["lr"])
    # drop_last=False: every validation pixel must be rendered and scored
    val_dataloader = DataLoader(val_camera, batch_size=cfg["evaluation"]["batch_size"], shuffle=True, num_workers=0, drop_last=False)
    for epoch in tqdm(range(cfg["training"]["epochs"])):
        canvas, losses = train(train_dataloader, scene, optimizer, device, (len(i_split[0]), H, W, 4))
        canvas = canvas.cpu().detach().numpy()
        epoch_dir = os.path.join(log_dir, f"epoch_{epoch:02}")
        train_dir = os.path.join(epoch_dir, "train")
        val_dir = os.path.join(epoch_dir, "val")
        for directory in (epoch_dir, train_dir, val_dir):
            os.makedirs(directory, exist_ok=True)

        scene.save(os.path.join(epoch_dir, "model.npz"))
        np.save(os.path.join(epoch_dir, "render.npy"), canvas)
        with open(os.path.join(epoch_dir, "losses.pkl"), "wb") as f:
            pickle.dump(losses, f)
        for i, image in enumerate(canvas):
            plt.imsave(os.path.join(train_dir, f"render_{i:02}.png"), np.clip(image, 0, 1))

        canvas, mse, psnr = eval(val_dataloader, scene, device, (len(i_split[1]), H, W, 4))
        canvas = canvas.cpu().detach().numpy()
        with open(os.path.join(epoch_dir, "val.pkl"), "wb") as f:
            pickle.dump([mse, psnr], f)
        for i, image in enumerate(canvas):
            # Rescale RGB to [0, 1] per image; a flat image would otherwise divide by zero
            rgb = image[:, :, :3]
            rgb -= rgb.min()
            span = rgb.max()
            if span > 0:
                rgb /= span
            plt.imsave(os.path.join(val_dir, f"render_{i:02}.png"), np.clip(image, 0, 1))
# Train and evaluate with the settings in this YAML file (see main for the keys it reads)
config = "configs/lego_test.yaml"
main(config)
{'name': 'lego_test', 'device': 'cuda', 'seed': 100, 'dataset': 'lego', 'downsample_factor': 16, 'model': {'num_vertices': 5000, 'knn': 5, 'sph_degree': 2}, 'training': {'epochs': 10, 'batch_size': 200, 'lr': 0.0005}, 'evaluation': {'batch_size': 200}}
cuda
torch.Size([400, 50, 50, 4])
c:\Users\Victor\anaconda3\envs\minimal\lib\site-packages\torch\functional.py:568: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\TensorShape.cpp:2228.) return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined] 0%| | 0/10 [00:00<?, ?it/s]
loss: 0.107788
loss: 0.098396
loss: 0.095989
loss: 0.094471
loss: 0.075324
loss: 0.083629
loss: 0.067422
loss: 0.068058
loss: 0.074677
loss: 0.069127
loss: 0.056110
loss: 0.057298
loss: 0.050010
100%|██████████| 1250/1250 [00:56<00:00, 22.10it/s] 100%|██████████| 1250/1250 [00:29<00:00, 42.15it/s] 10%|█ | 1/10 [01:28<13:17, 88.63s/it]
loss: 0.045145
loss: 0.039960
loss: 0.046379
loss: 0.036461
loss: 0.035219
loss: 0.041457
loss: 0.043703
loss: 0.036338
loss: 0.034461
loss: 0.026351
loss: 0.031880
loss: 0.034106
loss: 0.029560
100%|██████████| 1250/1250 [01:00<00:00, 20.77it/s] 100%|██████████| 1250/1250 [00:30<00:00, 41.23it/s] 20%|██ | 2/10 [02:59<11:59, 89.99s/it]
loss: 0.025219
loss: 0.036270
loss: 0.027327
loss: 0.026023
loss: 0.025894
loss: 0.028587
loss: 0.025234
loss: 0.030776
loss: 0.028220
loss: 0.016638
loss: 0.016646
loss: 0.019213
loss: 0.021666
100%|██████████| 1250/1250 [00:56<00:00, 22.00it/s] 100%|██████████| 1250/1250 [00:28<00:00, 43.97it/s] 30%|███ | 3/10 [04:25<10:16, 88.01s/it]
loss: 0.020747
loss: 0.027556
loss: 0.020424
loss: 0.017833
loss: 0.021749
loss: 0.023605
loss: 0.020714
loss: 0.020136
loss: 0.013131
loss: 0.015171
loss: 0.019533
loss: 0.027280
loss: 0.017268
100%|██████████| 1250/1250 [00:57<00:00, 21.90it/s] 100%|██████████| 1250/1250 [00:28<00:00, 43.41it/s] 40%|████ | 4/10 [05:51<08:43, 87.33s/it]
loss: 0.023692
loss: 0.015674
loss: 0.017008
loss: 0.022578
loss: 0.018253
loss: 0.026133
loss: 0.014525
loss: 0.017504
loss: 0.013242
loss: 0.024178
loss: 0.011930
loss: 0.017573
loss: 0.018462
100%|██████████| 1250/1250 [00:57<00:00, 21.79it/s] 100%|██████████| 1250/1250 [00:28<00:00, 44.53it/s] 50%|█████ | 5/10 [07:17<07:13, 86.79s/it]
loss: 0.013991
loss: 0.013139
loss: 0.012921
loss: 0.012921
loss: 0.019782
loss: 0.016692
loss: 0.013594
56%|█████▋ | 705/1250 [00:32<00:24, 22.24it/s]
loss: 0.022661
loss: 0.016196
loss: 0.014479
loss: 0.013854
loss: 0.017427
loss: 0.019575
100%|██████████| 1250/1250 [00:57<00:00, 21.90it/s] 100%|██████████| 1250/1250 [00:28<00:00, 43.28it/s] 60%|██████ | 6/10 [08:43<05:46, 86.64s/it]
loss: 0.008935
loss: 0.021477
loss: 0.017324
loss: 0.019286
loss: 0.019286
loss: 0.020070
loss: 0.017268
loss: 0.020512
loss: 0.010202
loss: 0.016491
loss: 0.019459
loss: 0.017550
loss: 0.019132
100%|██████████| 1250/1250 [00:54<00:00, 22.85it/s] 100%|██████████| 1250/1250 [00:26<00:00, 47.11it/s] 70%|███████ | 7/10 [10:05<04:14, 84.98s/it]
loss: 0.017064
loss: 0.012955
loss: 0.013177
loss: 0.020631
loss: 0.016104
loss: 0.014560
loss: 0.020707
loss: 0.007977
loss: 0.018706
loss: 0.012811
loss: 0.014453
loss: 0.019291
loss: 0.014513
100%|██████████| 1250/1250 [00:53<00:00, 23.47it/s] 100%|██████████| 1250/1250 [00:26<00:00, 46.35it/s] 80%|████████ | 8/10 [11:25<02:47, 83.58s/it]
loss: 0.018958
loss: 0.014288
loss: 0.021307
loss: 0.019521
loss: 0.017193
loss: 0.013569
loss: 0.015859
loss: 0.012770
loss: 0.019757
loss: 0.017051
loss: 0.016520
loss: 0.016820
loss: 0.016842
100%|██████████| 1250/1250 [00:53<00:00, 23.31it/s] 100%|██████████| 1250/1250 [00:26<00:00, 46.76it/s] 90%|█████████ | 9/10 [12:46<01:22, 82.68s/it]
loss: 0.014662
loss: 0.026067
loss: 0.015909
loss: 0.018041
loss: 0.015095
loss: 0.023086
loss: 0.017087
loss: 0.020385
loss: 0.022877
loss: 0.014520
loss: 0.019186
loss: 0.018283
loss: 0.012703
100%|██████████| 1250/1250 [00:53<00:00, 23.33it/s] 100%|██████████| 1250/1250 [00:26<00:00, 47.02it/s] 100%|██████████| 10/10 [14:07<00:00, 84.70s/it]
<Figure size 640x480 with 0 Axes>